1 | #pragma once |
2 | |
3 | #include "taichi/ir/ir.h" |
4 | #include "taichi/ir/offloaded_task_type.h" |
5 | #include "taichi/ir/stmt_op_types.h" |
6 | #include "taichi/rhi/arch.h" |
7 | #include "taichi/ir/mesh.h" |
8 | |
9 | #include <optional> |
10 | |
11 | namespace taichi::lang { |
12 | |
13 | class Function; |
14 | |
15 | /** |
16 | * Allocate a local variable with initial value 0. |
17 | */ |
18 | class AllocaStmt : public Stmt, public ir_traits::Store { |
19 | public: |
20 | explicit AllocaStmt(DataType type) : is_shared(false) { |
21 | ret_type = type; |
22 | TI_STMT_REG_FIELDS; |
23 | } |
24 | |
25 | AllocaStmt(const std::vector<int> &shape, |
26 | DataType type, |
27 | bool is_shared = false) |
28 | : is_shared(is_shared) { |
29 | ret_type = TypeFactory::create_tensor_type(shape, type); |
30 | TI_STMT_REG_FIELDS; |
31 | } |
32 | |
33 | bool has_global_side_effect() const override { |
34 | return false; |
35 | } |
36 | |
37 | bool common_statement_eliminable() const override { |
38 | return false; |
39 | } |
40 | |
41 | // IR Trait: Store |
42 | stmt_refs get_store_destination() const override { |
43 | // The statement itself provides a data source (const [0]). |
44 | return ret_type->is<TensorType>() ? nullptr : (Stmt *)this; |
45 | } |
46 | |
47 | Stmt *get_store_data() const override { |
48 | // For convenience, return store_stmt instead of the const [0] it actually |
49 | // stores. |
50 | return ret_type->is<TensorType>() ? nullptr : (Stmt *)this; |
51 | } |
52 | |
53 | bool is_shared; |
54 | TI_STMT_DEF_FIELDS(ret_type, is_shared); |
55 | TI_DEFINE_ACCEPT_AND_CLONE |
56 | }; |
57 | |
58 | /** |
59 | * Updates mask, break if all bits of the mask are 0. |
60 | */ |
61 | class WhileControlStmt : public Stmt { |
62 | public: |
63 | Stmt *mask; |
64 | Stmt *cond; |
65 | WhileControlStmt(Stmt *mask, Stmt *cond) : mask(mask), cond(cond) { |
66 | TI_STMT_REG_FIELDS; |
67 | } |
68 | |
69 | TI_STMT_DEF_FIELDS(mask, cond); |
70 | TI_DEFINE_ACCEPT_AND_CLONE; |
71 | }; |
72 | |
73 | /** |
74 | * Jump to the next loop iteration, i.e., `continue` in C++. |
75 | */ |
76 | class ContinueStmt : public Stmt { |
77 | public: |
78 | // This is the loop on which this continue stmt has effects. It can be either |
79 | // an offloaded task, or a for/while loop inside the kernel. |
80 | Stmt *scope; |
81 | |
82 | ContinueStmt() : scope(nullptr) { |
83 | TI_STMT_REG_FIELDS; |
84 | } |
85 | |
86 | // For top-level loops, since they are parallelized to multiple threads (on |
87 | // either CPU or GPU), `continue` becomes semantically equivalent to `return`. |
88 | // |
89 | // Caveat: |
90 | // We should wrap each backend's kernel body into a function (as LLVM does). |
91 | // The reason is that, each thread may handle more than one element, |
92 | // depending on the backend's implementation. |
93 | // |
94 | // For example, CUDA uses grid-stride loops, the snippet below illustrates |
95 | // the idea: |
96 | // |
97 | // __global__ foo_kernel(...) { |
98 | // for (int i = lower; i < upper; i += gridDim) { |
99 | // auto coord = compute_coords(i); |
100 | // // run_foo_kernel is produced by codegen |
101 | // run_foo_kernel(coord); |
102 | // } |
103 | // } |
104 | // |
105 | // If run_foo_kernel() is directly inlined within foo_kernel(), `return` |
106 | // could prematurely terminate the entire kernel. |
107 | |
108 | TI_STMT_DEF_FIELDS(scope); |
109 | TI_DEFINE_ACCEPT_AND_CLONE; |
110 | }; |
111 | |
112 | /** |
113 | * A decoration statement. The decorated "operands" will keep this decoration. |
114 | */ |
115 | class DecorationStmt : public Stmt { |
116 | public: |
117 | enum class Decoration : uint32_t { kUnknown, kLoopUnique }; |
118 | |
119 | Stmt *operand; |
120 | std::vector<uint32_t> decoration; |
121 | |
122 | DecorationStmt(Stmt *operand, const std::vector<uint32_t> &decoration); |
123 | |
124 | bool same_operation(DecorationStmt *o) const { |
125 | return false; |
126 | } |
127 | |
128 | bool is_cast() const { |
129 | return false; |
130 | } |
131 | |
132 | bool has_global_side_effect() const override { |
133 | return false; |
134 | } |
135 | |
136 | bool dead_instruction_eliminable() const override { |
137 | return false; |
138 | } |
139 | |
140 | TI_STMT_DEF_FIELDS(operand, decoration); |
141 | TI_DEFINE_ACCEPT_AND_CLONE |
142 | }; |
143 | |
144 | /** |
145 | * A unary operation. The field |cast_type| is used only when is_cast() is true. |
146 | */ |
147 | class UnaryOpStmt : public Stmt { |
148 | public: |
149 | UnaryOpType op_type; |
150 | Stmt *operand; |
151 | DataType cast_type; |
152 | |
153 | UnaryOpStmt(UnaryOpType op_type, Stmt *operand); |
154 | |
155 | bool same_operation(UnaryOpStmt *o) const; |
156 | bool is_cast() const; |
157 | |
158 | bool has_global_side_effect() const override { |
159 | return false; |
160 | } |
161 | |
162 | TI_STMT_DEF_FIELDS(ret_type, op_type, operand, cast_type); |
163 | TI_DEFINE_ACCEPT_AND_CLONE |
164 | }; |
165 | |
166 | /** |
167 | * Load a kernel argument. The data type should be known when constructing this |
168 | * statement. |is_ptr| should be true iff the result can be used as a base |
169 | * pointer of an ExternalPtrStmt. |
170 | */ |
171 | class ArgLoadStmt : public Stmt { |
172 | public: |
173 | int arg_id; |
174 | |
175 | /* TODO(zhanlue): more organized argument-type information |
176 | |
177 | ArgLoadStmt is able to load everything passed into the kernel, |
178 | including but not limited to: scalar, matrix, snode_tree_types(WIP), |
179 | ndarray, ... |
180 | |
181 | Therefore we need to add a field to indicate the type of the argument. For |
182 | now, only "is_ptr" and "field_dims" is needed. |
183 | |
184 | */ |
185 | bool is_ptr; |
186 | |
187 | bool is_grad; |
188 | |
189 | // field_dims of ndarray |
190 | int field_dims_ = 0; |
191 | |
192 | ArgLoadStmt(int arg_id, |
193 | const DataType &dt, |
194 | bool is_ptr = false, |
195 | bool is_grad = false) |
196 | : arg_id(arg_id) { |
197 | this->ret_type = dt; |
198 | this->is_ptr = is_ptr; |
199 | this->is_grad = is_grad; |
200 | this->field_dims_ = -1; // -1 means uninitialized |
201 | TI_STMT_REG_FIELDS; |
202 | } |
203 | |
204 | void set_extern_dims(int dims) { |
205 | this->field_dims_ = dims; |
206 | } |
207 | |
208 | bool has_global_side_effect() const override { |
209 | return false; |
210 | } |
211 | |
212 | TI_STMT_DEF_FIELDS(ret_type, arg_id, is_ptr); |
213 | TI_DEFINE_ACCEPT_AND_CLONE |
214 | }; |
215 | |
216 | /** |
217 | * A random value. For i32, u32, i64, and u64, the result is randomly sampled |
218 | * from all possible values with equal probability. For f32 and f64 data types, |
219 | * the result is uniformly sampled in the interval [0, 1). |
220 | * When Taichi runtime initializes, each CUDA thread / CPU thread gets a |
221 | * different (but deterministic as long as the thread id doesn't change) |
222 | * random seed. Each invocation of a RandStmt compiles to a call of a |
223 | * deterministic PRNG to generate a random value in the backend. |
224 | */ |
225 | class RandStmt : public Stmt { |
226 | public: |
227 | explicit RandStmt(const DataType &dt) { |
228 | ret_type = dt; |
229 | TI_STMT_REG_FIELDS; |
230 | } |
231 | |
232 | bool has_global_side_effect() const override { |
233 | return false; |
234 | } |
235 | |
236 | bool common_statement_eliminable() const override { |
237 | return false; |
238 | } |
239 | |
240 | TI_STMT_DEF_FIELDS(ret_type); |
241 | TI_DEFINE_ACCEPT_AND_CLONE |
242 | }; |
243 | |
244 | /** |
245 | * A binary operation. |
246 | */ |
247 | class BinaryOpStmt : public Stmt { |
248 | public: |
249 | BinaryOpType op_type; |
250 | Stmt *lhs, *rhs; |
251 | bool is_bit_vectorized; // TODO: remove this field |
252 | |
253 | BinaryOpStmt(BinaryOpType op_type, |
254 | Stmt *lhs, |
255 | Stmt *rhs, |
256 | bool is_bit_vectorized = false) |
257 | : op_type(op_type), |
258 | lhs(lhs), |
259 | rhs(rhs), |
260 | is_bit_vectorized(is_bit_vectorized) { |
261 | TI_ASSERT(!lhs->is<AllocaStmt>()); |
262 | TI_ASSERT(!rhs->is<AllocaStmt>()); |
263 | TI_STMT_REG_FIELDS; |
264 | } |
265 | |
266 | bool has_global_side_effect() const override { |
267 | return false; |
268 | } |
269 | |
270 | TI_STMT_DEF_FIELDS(ret_type, op_type, lhs, rhs, is_bit_vectorized); |
271 | TI_DEFINE_ACCEPT_AND_CLONE |
272 | }; |
273 | |
274 | /** |
275 | * A ternary operation. Currently "select" (the ternary conditional operator, |
276 | * "?:" in C++) is the only supported ternary operation. |
277 | */ |
278 | class TernaryOpStmt : public Stmt { |
279 | public: |
280 | TernaryOpType op_type; |
281 | Stmt *op1, *op2, *op3; |
282 | |
283 | TernaryOpStmt(TernaryOpType op_type, Stmt *op1, Stmt *op2, Stmt *op3) |
284 | : op_type(op_type), op1(op1), op2(op2), op3(op3) { |
285 | TI_ASSERT(!op1->is<AllocaStmt>()); |
286 | TI_ASSERT(!op2->is<AllocaStmt>()); |
287 | TI_ASSERT(!op3->is<AllocaStmt>()); |
288 | TI_STMT_REG_FIELDS; |
289 | } |
290 | |
291 | bool has_global_side_effect() const override { |
292 | return false; |
293 | } |
294 | |
295 | TI_STMT_DEF_FIELDS(ret_type, op1, op2, op3); |
296 | TI_DEFINE_ACCEPT_AND_CLONE |
297 | }; |
298 | |
299 | /** |
300 | * An atomic operation. |
301 | */ |
302 | class AtomicOpStmt : public Stmt, |
303 | public ir_traits::Store, |
304 | public ir_traits::Load { |
305 | public: |
306 | AtomicOpType op_type; |
307 | Stmt *dest, *val; |
308 | bool is_reduction; |
309 | |
310 | AtomicOpStmt(AtomicOpType op_type, Stmt *dest, Stmt *val) |
311 | : op_type(op_type), dest(dest), val(val), is_reduction(false) { |
312 | TI_STMT_REG_FIELDS; |
313 | } |
314 | |
315 | static std::unique_ptr<AtomicOpStmt> make_for_reduction(AtomicOpType op_type, |
316 | Stmt *dest, |
317 | Stmt *val) { |
318 | auto stmt = std::make_unique<AtomicOpStmt>(op_type, dest, val); |
319 | stmt->is_reduction = true; |
320 | return stmt; |
321 | } |
322 | |
323 | // IR Trait: Store |
324 | stmt_refs get_store_destination() const override { |
325 | return dest; |
326 | } |
327 | |
328 | Stmt *get_store_data() const override { |
329 | return nullptr; |
330 | } |
331 | |
332 | // IR Trait: Load |
333 | stmt_refs get_load_pointers() const override { |
334 | return dest; |
335 | } |
336 | |
337 | TI_STMT_DEF_FIELDS(ret_type, op_type, dest, val); |
338 | TI_DEFINE_ACCEPT_AND_CLONE |
339 | }; |
340 | |
341 | /** |
342 | * An external pointer. |base_ptr| should be ArgLoadStmt with |
343 | * |is_ptr| == true. |
344 | */ |
345 | class ExternalPtrStmt : public Stmt { |
346 | public: |
347 | Stmt *base_ptr; |
348 | std::vector<Stmt *> indices; |
349 | std::vector<int> element_shape; |
350 | // AOS: element_dim < 0 |
351 | // SOA: element_dim > 0 |
352 | int element_dim; |
353 | |
354 | ExternalPtrStmt(Stmt *base_ptr, const std::vector<Stmt *> &indices); |
355 | |
356 | ExternalPtrStmt(Stmt *base_ptr, |
357 | const std::vector<Stmt *> &indices, |
358 | const std::vector<int> &element_shape, |
359 | int element_dim); |
360 | |
361 | bool has_global_side_effect() const override { |
362 | return false; |
363 | } |
364 | |
365 | TI_STMT_DEF_FIELDS(ret_type, base_ptr, indices); |
366 | TI_DEFINE_ACCEPT_AND_CLONE |
367 | }; |
368 | |
369 | /** |
370 | * A global pointer, currently only able to represent an address in a SNode. |
371 | * When |activate| is true, this statement activates the address it points to, |
372 | * so it has "global side effect" in this case. |
373 | * After the "lower_access" pass, all GlobalPtrStmts should be lowered into |
374 | * SNodeLookupStmts and GetChStmts, and should not appear in the final lowered |
375 | * IR. |
376 | */ |
377 | class GlobalPtrStmt : public Stmt { |
378 | public: |
379 | SNode *snode; |
380 | std::vector<Stmt *> indices; |
381 | bool activate; |
382 | bool is_cell_access; |
383 | bool is_bit_vectorized; // for bit_loop_vectorize pass |
384 | |
385 | GlobalPtrStmt(SNode *snode, |
386 | const std::vector<Stmt *> &indices, |
387 | bool activate = true, |
388 | bool is_cell_access = false); |
389 | |
390 | bool has_global_side_effect() const override { |
391 | return activate; |
392 | } |
393 | |
394 | bool common_statement_eliminable() const override { |
395 | return true; |
396 | } |
397 | |
398 | TI_STMT_DEF_FIELDS(ret_type, snode, indices, activate, is_bit_vectorized); |
399 | TI_DEFINE_ACCEPT_AND_CLONE |
400 | }; |
401 | |
402 | /** |
403 | * An "abstract" pointer for an element of a MatrixField, which logically |
404 | * contains a matrix of GlobalPtrStmts. Upon construction, only snodes, indices, |
405 | * dynamic_indexable, dynamic_index_stride and activate are initialized. After |
406 | * the lower_matrix_ptr pass, this stmt will either be eliminated (constant |
407 | * index) or have ptr_base initialized (dynamic index or whole-matrix access). |
408 | */ |
409 | class MatrixOfGlobalPtrStmt : public Stmt { |
410 | public: |
411 | std::vector<SNode *> snodes; |
412 | std::vector<Stmt *> indices; |
413 | Stmt *ptr_base{nullptr}; |
414 | bool dynamic_indexable{false}; |
415 | int dynamic_index_stride{0}; |
416 | bool activate{true}; |
417 | |
418 | MatrixOfGlobalPtrStmt(const std::vector<SNode *> &snodes, |
419 | const std::vector<Stmt *> &indices, |
420 | bool dynamic_indexable, |
421 | int dynamic_index_stride, |
422 | DataType dt, |
423 | bool activate = true); |
424 | |
425 | bool has_global_side_effect() const override { |
426 | return activate; |
427 | } |
428 | |
429 | bool common_statement_eliminable() const override { |
430 | return true; |
431 | } |
432 | |
433 | TI_STMT_DEF_FIELDS(ret_type, |
434 | snodes, |
435 | indices, |
436 | ptr_base, |
437 | dynamic_indexable, |
438 | dynamic_index_stride, |
439 | activate); |
440 | TI_DEFINE_ACCEPT_AND_CLONE |
441 | }; |
442 | |
443 | /** |
444 | * A matrix of MatrixPtrStmts. The purpose of this stmt is to handle matrix |
445 | * slice and vector swizzle. This stmt will be eliminated after the |
446 | * lower_matrix_ptr pass. |
447 | * |
448 | * TODO(yi/zhanlue): Keep scalarization pass alive for MatrixOfMatrixPtrStmt |
449 | * operations even with real_matrix_scalarize=False |
450 | */ |
451 | class MatrixOfMatrixPtrStmt : public Stmt { |
452 | public: |
453 | std::vector<Stmt *> stmts; |
454 | |
455 | MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts, DataType dt); |
456 | |
457 | TI_STMT_DEF_FIELDS(ret_type, stmts); |
458 | TI_DEFINE_ACCEPT_AND_CLONE |
459 | }; |
460 | |
461 | /** |
462 | * A pointer to an element of a matrix. |
463 | */ |
464 | class MatrixPtrStmt : public Stmt { |
465 | public: |
466 | Stmt *origin{nullptr}; |
467 | Stmt *offset{nullptr}; |
468 | |
469 | MatrixPtrStmt(Stmt *, Stmt *, const std::string & = "" ); |
470 | |
471 | /* TODO(zhanlue/yi): Unify semantics of offset in MatrixPtrStmt |
472 | |
473 | There is a hack in MatrixPtrStmt in terms of the semantics of "offset", |
474 | where "offset" can be interpreted as "number of bytes" or "index" in |
475 | different upper-level code paths |
476 | |
477 | Here we created this offset_used_as_index() function to help indentify |
478 | "offset"'s semantic, but in the end we should unify these two semantics. |
479 | */ |
480 | bool offset_used_as_index() const { |
481 | if (origin->is<AllocaStmt>() || origin->is<GlobalTemporaryStmt>() || |
482 | origin->is<ExternalPtrStmt>()) { |
483 | TI_ASSERT_INFO(origin->ret_type.ptr_removed()->is<TensorType>(), |
484 | "MatrixPtrStmt can only be used for TensorType." ); |
485 | return true; |
486 | } |
487 | return false; |
488 | } |
489 | |
490 | bool is_unlowered_global_ptr() const { |
491 | return origin->is<GlobalPtrStmt>(); |
492 | } |
493 | |
494 | bool has_global_side_effect() const override { |
495 | // After access lowered, activate info will be recorded in SNodeLookupStmt's |
496 | // activate for AOS sparse data structure. We don't support SOA sparse data |
497 | // structure for now. |
498 | return false; |
499 | } |
500 | |
501 | TI_STMT_DEF_FIELDS(ret_type, origin, offset); |
502 | TI_DEFINE_ACCEPT_AND_CLONE |
503 | }; |
504 | |
505 | /** |
506 | * An operation to a SNode (not necessarily a leaf SNode). |
507 | */ |
508 | class SNodeOpStmt : public Stmt, public ir_traits::Store { |
509 | public: |
510 | SNodeOpType op_type; |
511 | SNode *snode; |
512 | Stmt *ptr; |
513 | Stmt *val; |
514 | |
515 | SNodeOpStmt(SNodeOpType op_type, |
516 | SNode *snode, |
517 | Stmt *ptr, |
518 | Stmt *val = nullptr); |
519 | |
520 | static bool activation_related(SNodeOpType op); |
521 | |
522 | static bool need_activation(SNodeOpType op); |
523 | |
524 | // IR Trait: Store |
525 | stmt_refs get_store_destination() const override { |
526 | if (op_type == SNodeOpType::allocate) { |
527 | return std::vector<Stmt *>{val, ptr}; |
528 | } else { |
529 | return nullptr; |
530 | } |
531 | } |
532 | |
533 | Stmt *get_store_data() const override { |
534 | return nullptr; |
535 | } |
536 | |
537 | TI_STMT_DEF_FIELDS(ret_type, op_type, snode, ptr, val); |
538 | TI_DEFINE_ACCEPT_AND_CLONE |
539 | }; |
540 | |
// TODO: remove this
// (penguinliong) This Stmt is used for both ND-arrays and textures. This is
// subject to change in the future.
/**
 * Queries the shape of the external tensor (ND-array / texture) bound to
 * kernel argument |arg_id| along dimension |axis|.
 */
class ExternalTensorShapeAlongAxisStmt : public Stmt {
 public:
  int axis;
  int arg_id;

  ExternalTensorShapeAlongAxisStmt(int axis, int arg_id);

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, axis, arg_id);
  TI_DEFINE_ACCEPT_AND_CLONE
};
558 | |
559 | /** |
560 | * An assertion. |
561 | * If |cond| is false, print the formatted |text| with |args|, and terminate |
562 | * the program. |
563 | */ |
564 | class AssertStmt : public Stmt { |
565 | public: |
566 | Stmt *cond; |
567 | std::string text; |
568 | std::vector<Stmt *> args; |
569 | |
570 | AssertStmt(Stmt *cond, |
571 | const std::string &text, |
572 | const std::vector<Stmt *> &args) |
573 | : cond(cond), text(text), args(args) { |
574 | TI_ASSERT(cond); |
575 | TI_STMT_REG_FIELDS; |
576 | } |
577 | |
578 | TI_STMT_DEF_FIELDS(cond, text, args); |
579 | TI_DEFINE_ACCEPT_AND_CLONE |
580 | }; |
581 | |
582 | /** |
583 | * Call an external (C++) function. |
584 | */ |
585 | class ExternalFuncCallStmt : public Stmt, |
586 | public ir_traits::Store, |
587 | public ir_traits::Load { |
588 | public: |
589 | enum Type { SHARED_OBJECT = 0, ASSEMBLY = 1, BITCODE = 2 }; |
590 | |
591 | Type type; |
592 | void *so_func; // SHARED_OBJECT |
593 | std::string asm_source; // ASM |
594 | std::string bc_filename; // BITCODE |
595 | std::string bc_funcname; // BITCODE |
596 | std::vector<Stmt *> arg_stmts; |
597 | std::vector<Stmt *> output_stmts; // BITCODE doesn't use this |
598 | |
599 | ExternalFuncCallStmt(Type type, |
600 | void *so_func, |
601 | std::string asm_source, |
602 | std::string bc_filename, |
603 | std::string bc_funcname, |
604 | const std::vector<Stmt *> &arg_stmts, |
605 | const std::vector<Stmt *> &output_stmts) |
606 | : type(type), |
607 | so_func(so_func), |
608 | asm_source(asm_source), |
609 | bc_filename(bc_filename), |
610 | bc_funcname(bc_funcname), |
611 | arg_stmts(arg_stmts), |
612 | output_stmts(output_stmts) { |
613 | TI_STMT_REG_FIELDS; |
614 | } |
615 | |
616 | // IR Trait: Store |
617 | stmt_refs get_store_destination() const override { |
618 | if (type == ExternalFuncCallStmt::BITCODE) { |
619 | return arg_stmts; |
620 | } else { |
621 | return output_stmts; |
622 | } |
623 | } |
624 | |
625 | Stmt *get_store_data() const override { |
626 | return nullptr; |
627 | } |
628 | |
629 | // IR Trait: Load |
630 | stmt_refs get_load_pointers() const override { |
631 | return arg_stmts; |
632 | } |
633 | |
634 | TI_STMT_DEF_FIELDS(type, |
635 | so_func, |
636 | asm_source, |
637 | bc_filename, |
638 | bc_funcname, |
639 | arg_stmts, |
640 | output_stmts); |
641 | TI_DEFINE_ACCEPT_AND_CLONE |
642 | }; |
643 | |
644 | /** |
645 | * A hint to the Taichi compiler about the relation of the values of two |
646 | * statements. |
647 | * This statement simply returns the input statement at the backend, and hints |
648 | * the Taichi compiler that |base| + |low| <= |input| < |base| + |high|. |
649 | */ |
650 | class RangeAssumptionStmt : public Stmt { |
651 | public: |
652 | Stmt *input; |
653 | Stmt *base; |
654 | int low, high; |
655 | |
656 | RangeAssumptionStmt(Stmt *input, Stmt *base, int low, int high) |
657 | : input(input), base(base), low(low), high(high) { |
658 | TI_STMT_REG_FIELDS; |
659 | } |
660 | |
661 | bool has_global_side_effect() const override { |
662 | return false; |
663 | } |
664 | |
665 | TI_STMT_DEF_FIELDS(ret_type, input, base, low, high); |
666 | TI_DEFINE_ACCEPT_AND_CLONE |
667 | }; |
668 | |
669 | /** |
670 | * A hint to the Taichi compiler that a statement has unique values among |
671 | * the top-level loop. This statement simply returns the input statement at |
672 | * the backend, and hints the Taichi compiler that this statement never |
673 | * evaluate to the same value across different iterations of the top-level |
674 | * loop. This statement's value set among all iterations of the top-level loop |
675 | * also covers all active indices of each SNodes with id in the |covers| field |
676 | * of this statement. Since this statement can only evaluate to one value, |
677 | * the SNodes with id in the |covers| field should have only one dimension. |
678 | */ |
679 | class LoopUniqueStmt : public Stmt { |
680 | public: |
681 | Stmt *input; |
682 | std::unordered_set<int> covers; // Stores SNode id |
683 | // std::unordered_set<> provides operator==, and StmtFieldManager will |
684 | // use that to check if two LoopUniqueStmts are the same. |
685 | |
686 | LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers); |
687 | |
688 | bool has_global_side_effect() const override { |
689 | return false; |
690 | } |
691 | |
692 | TI_STMT_DEF_FIELDS(ret_type, input, covers); |
693 | TI_DEFINE_ACCEPT_AND_CLONE |
694 | }; |
695 | |
696 | /** |
697 | * A load from a global address, including SNodes, external arrays, TLS, BLS, |
698 | * and global temporary variables. |
699 | */ |
700 | class GlobalLoadStmt : public Stmt, public ir_traits::Load { |
701 | public: |
702 | Stmt *src; |
703 | |
704 | explicit GlobalLoadStmt(Stmt *src) : src(src) { |
705 | TI_STMT_REG_FIELDS; |
706 | } |
707 | |
708 | bool has_global_side_effect() const override { |
709 | return false; |
710 | } |
711 | |
712 | bool common_statement_eliminable() const override { |
713 | return false; |
714 | } |
715 | |
716 | // IR Trait: Load |
717 | stmt_refs get_load_pointers() const override { |
718 | return src; |
719 | } |
720 | |
721 | TI_STMT_DEF_FIELDS(ret_type, src); |
722 | TI_DEFINE_ACCEPT_AND_CLONE; |
723 | }; |
724 | |
725 | /** |
726 | * A store to a global address, including SNodes, external arrays, TLS, BLS, |
727 | * and global temporary variables. |
728 | */ |
729 | class GlobalStoreStmt : public Stmt, public ir_traits::Store { |
730 | public: |
731 | Stmt *dest; |
732 | Stmt *val; |
733 | |
734 | GlobalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) { |
735 | TI_STMT_REG_FIELDS; |
736 | } |
737 | |
738 | bool common_statement_eliminable() const override { |
739 | return false; |
740 | } |
741 | |
742 | // IR Trait: Store |
743 | stmt_refs get_store_destination() const override { |
744 | return dest; |
745 | } |
746 | |
747 | Stmt *get_store_data() const override { |
748 | return val; |
749 | } |
750 | |
751 | TI_STMT_DEF_FIELDS(ret_type, dest, val); |
752 | TI_DEFINE_ACCEPT_AND_CLONE; |
753 | }; |
754 | |
755 | /** |
756 | * A load from a local variable, i.e., an "alloca". |
757 | */ |
758 | class LocalLoadStmt : public Stmt, public ir_traits::Load { |
759 | public: |
760 | Stmt *src; |
761 | |
762 | explicit LocalLoadStmt(Stmt *src) : src(src) { |
763 | TI_STMT_REG_FIELDS; |
764 | } |
765 | |
766 | bool has_global_side_effect() const override { |
767 | return false; |
768 | } |
769 | |
770 | bool common_statement_eliminable() const override { |
771 | return false; |
772 | } |
773 | |
774 | // IR Trait: Load |
775 | stmt_refs get_load_pointers() const override { |
776 | return src; |
777 | } |
778 | |
779 | TI_STMT_DEF_FIELDS(ret_type, src); |
780 | TI_DEFINE_ACCEPT_AND_CLONE; |
781 | }; |
782 | |
783 | /** |
784 | * A store to a local variable, i.e., an "alloca". |
785 | */ |
786 | class LocalStoreStmt : public Stmt, public ir_traits::Store { |
787 | public: |
788 | Stmt *dest; |
789 | Stmt *val; |
790 | |
791 | LocalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) { |
792 | TI_ASSERT(dest->is<AllocaStmt>() || dest->is<MatrixPtrStmt>() || |
793 | dest->is<MatrixOfMatrixPtrStmt>()); |
794 | TI_STMT_REG_FIELDS; |
795 | } |
796 | |
797 | bool has_global_side_effect() const override { |
798 | return false; |
799 | } |
800 | |
801 | bool dead_instruction_eliminable() const override { |
802 | return false; |
803 | } |
804 | |
805 | bool common_statement_eliminable() const override { |
806 | return false; |
807 | } |
808 | |
809 | // IR Trait: Store |
810 | stmt_refs get_store_destination() const override { |
811 | return dest; |
812 | } |
813 | |
814 | Stmt *get_store_data() const override { |
815 | return val; |
816 | } |
817 | |
818 | TI_STMT_DEF_FIELDS(ret_type, dest, val); |
819 | TI_DEFINE_ACCEPT_AND_CLONE; |
820 | }; |
821 | |
822 | /** |
823 | * Same as "if (cond) true_statements; else false_statements;" in C++. |
824 | * |true_mask| and |false_mask| are used to support vectorization. |
825 | */ |
826 | class IfStmt : public Stmt { |
827 | public: |
828 | Stmt *cond; |
829 | std::unique_ptr<Block> true_statements, false_statements; |
830 | |
831 | explicit IfStmt(Stmt *cond); |
832 | |
833 | // Use these setters to set Block::parent_stmt at the same time. |
834 | void set_true_statements(std::unique_ptr<Block> &&new_true_statements); |
835 | void set_false_statements(std::unique_ptr<Block> &&new_false_statements); |
836 | |
837 | bool is_container_statement() const override { |
838 | return true; |
839 | } |
840 | |
841 | std::unique_ptr<Stmt> clone() const override; |
842 | |
843 | TI_STMT_DEF_FIELDS(cond); |
844 | TI_DEFINE_ACCEPT |
845 | }; |
846 | |
847 | /** |
848 | * Print the contents in this statement. Each entry in the contents can be |
849 | * either a statement or a string, and they are printed one by one, separated |
850 | * by a comma and a space. |
851 | */ |
852 | class PrintStmt : public Stmt { |
853 | public: |
854 | using EntryType = std::variant<Stmt *, std::string>; |
855 | std::vector<EntryType> contents; |
856 | |
857 | explicit PrintStmt(const std::vector<EntryType> &contents_) |
858 | : contents(contents_) { |
859 | TI_STMT_REG_FIELDS; |
860 | } |
861 | |
862 | template <typename... Args> |
863 | explicit PrintStmt(Stmt *t, Args &&...args) |
864 | : contents(make_entries(t, std::forward<Args>(args)...)) { |
865 | TI_STMT_REG_FIELDS; |
866 | } |
867 | |
868 | template <typename... Args> |
869 | explicit PrintStmt(const std::string &str, Args &&...args) |
870 | : contents(make_entries(str, std::forward<Args>(args)...)) { |
871 | TI_STMT_REG_FIELDS; |
872 | } |
873 | |
874 | TI_STMT_DEF_FIELDS(ret_type, contents); |
875 | TI_DEFINE_ACCEPT_AND_CLONE |
876 | |
877 | private: |
878 | static void make_entries_helper(std::vector<PrintStmt::EntryType> &entries) { |
879 | } |
880 | |
881 | template <typename T, typename... Args> |
882 | static void make_entries_helper(std::vector<PrintStmt::EntryType> &entries, |
883 | T &&t, |
884 | Args &&...values) { |
885 | entries.push_back(EntryType{t}); |
886 | make_entries_helper(entries, std::forward<Args>(values)...); |
887 | } |
888 | |
889 | template <typename... Args> |
890 | static std::vector<EntryType> make_entries(Args &&...values) { |
891 | std::vector<EntryType> ret; |
892 | make_entries_helper(ret, std::forward<Args>(values)...); |
893 | return ret; |
894 | } |
895 | }; |
896 | |
897 | /** |
898 | * A constant value. |
899 | */ |
900 | class ConstStmt : public Stmt { |
901 | public: |
902 | TypedConstant val; |
903 | |
904 | explicit ConstStmt(const TypedConstant &val) : val(val) { |
905 | ret_type = val.dt; |
906 | TI_STMT_REG_FIELDS; |
907 | } |
908 | |
909 | bool has_global_side_effect() const override { |
910 | return false; |
911 | } |
912 | |
913 | TI_STMT_DEF_FIELDS(ret_type, val); |
914 | TI_DEFINE_ACCEPT_AND_CLONE |
915 | }; |
916 | |
917 | /** |
918 | * A general range for, similar to "for (i = begin; i < end; i++) body;" in C++. |
919 | * When |reversed| is true, the for loop is reversed, i.e., |
920 | * "for (i = end - 1; i >= begin; i--) body;". |
921 | * When the statement is in the top level before offloading, it will be |
922 | * offloaded to a parallel for loop. Otherwise, it will be offloaded to a |
923 | * serial for loop. |
924 | */ |
925 | class RangeForStmt : public Stmt { |
926 | public: |
927 | Stmt *begin, *end; |
928 | std::unique_ptr<Block> body; |
929 | bool reversed; |
930 | bool is_bit_vectorized; |
931 | int num_cpu_threads; |
932 | int block_dim; |
933 | bool strictly_serialized; |
934 | std::string range_hint; |
935 | |
936 | RangeForStmt(Stmt *begin, |
937 | Stmt *end, |
938 | std::unique_ptr<Block> &&body, |
939 | bool is_bit_vectorized, |
940 | int num_cpu_threads, |
941 | int block_dim, |
942 | bool strictly_serialized, |
943 | std::string range_hint = "" ); |
944 | |
945 | bool is_container_statement() const override { |
946 | return true; |
947 | } |
948 | |
949 | void reverse() { |
950 | reversed = !reversed; |
951 | } |
952 | |
953 | std::unique_ptr<Stmt> clone() const override; |
954 | |
955 | TI_STMT_DEF_FIELDS(begin, |
956 | end, |
957 | reversed, |
958 | is_bit_vectorized, |
959 | num_cpu_threads, |
960 | block_dim, |
961 | strictly_serialized); |
962 | TI_DEFINE_ACCEPT |
963 | }; |
964 | |
965 | /** |
966 | * A parallel for loop over a SNode, similar to "for i in snode: body" |
967 | * in Python. This statement must be at the top level before offloading. |
968 | */ |
969 | class StructForStmt : public Stmt { |
970 | public: |
971 | SNode *snode; |
972 | std::unique_ptr<Block> body; |
973 | std::unique_ptr<Block> block_initialization; |
974 | std::unique_ptr<Block> block_finalization; |
975 | std::vector<int> index_offsets; |
976 | bool is_bit_vectorized; |
977 | int num_cpu_threads; |
978 | int block_dim; |
979 | MemoryAccessOptions mem_access_opt; |
980 | |
981 | StructForStmt(SNode *snode, |
982 | std::unique_ptr<Block> &&body, |
983 | bool is_bit_vectorized, |
984 | int num_cpu_threads, |
985 | int block_dim); |
986 | |
987 | bool is_container_statement() const override { |
988 | return true; |
989 | } |
990 | |
991 | std::unique_ptr<Stmt> clone() const override; |
992 | |
993 | TI_STMT_DEF_FIELDS(snode, |
994 | index_offsets, |
995 | is_bit_vectorized, |
996 | num_cpu_threads, |
997 | block_dim, |
998 | mem_access_opt); |
999 | TI_DEFINE_ACCEPT |
1000 | }; |
1001 | |
1002 | /** |
1003 | * meshfor |
1004 | */ |
1005 | class MeshForStmt : public Stmt { |
1006 | public: |
1007 | mesh::Mesh *mesh; |
1008 | std::unique_ptr<Block> body; |
1009 | bool is_bit_vectorized; |
1010 | int num_cpu_threads; |
1011 | int block_dim; |
1012 | mesh::MeshElementType major_from_type; |
1013 | std::unordered_set<mesh::MeshElementType> major_to_types{}; |
1014 | std::unordered_set<mesh::MeshRelationType> minor_relation_types{}; |
1015 | MemoryAccessOptions mem_access_opt; |
1016 | |
1017 | MeshForStmt(mesh::Mesh *mesh, |
1018 | mesh::MeshElementType element_type, |
1019 | std::unique_ptr<Block> &&body, |
1020 | bool is_bit_vectorized, |
1021 | int num_cpu_threads, |
1022 | int block_dim); |
1023 | |
1024 | bool is_container_statement() const override { |
1025 | return true; |
1026 | } |
1027 | |
1028 | std::unique_ptr<Stmt> clone() const override; |
1029 | |
1030 | TI_STMT_DEF_FIELDS(mesh, |
1031 | is_bit_vectorized, |
1032 | num_cpu_threads, |
1033 | block_dim, |
1034 | major_from_type, |
1035 | major_to_types, |
1036 | minor_relation_types, |
1037 | mem_access_opt); |
1038 | TI_DEFINE_ACCEPT |
1039 | }; |
1040 | |
1041 | /** |
1042 | * Call an inline Taichi function. |
1043 | */ |
class FuncCallStmt : public Stmt {
 public:
  Function *func;
  std::vector<Stmt *> args;
  // Whether the callee may have global side effects; defaults to the
  // conservative value true.
  bool global_side_effect{true};

  FuncCallStmt(Function *func, const std::vector<Stmt *> &args);

  bool has_global_side_effect() const override {
    return global_side_effect;
  }

  // NOTE(review): |global_side_effect| is intentionally(?) absent from the
  // field list below — confirm it should not participate in field-based
  // comparison.
  TI_STMT_DEF_FIELDS(ret_type, func, args);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1059 | |
1060 | /** |
1061 | * A reference to a variable. |
1062 | */ |
class ReferenceStmt : public Stmt, public ir_traits::Load {
 public:
  // The variable being referenced.
  Stmt *var;
  bool global_side_effect{false};

  explicit ReferenceStmt(Stmt *var) : var(var) {
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return global_side_effect;
  }

  // IR Trait: Load
  // Taking a reference counts as a load of |var|.
  stmt_refs get_load_pointers() const override {
    return var;
  }

  TI_STMT_DEF_FIELDS(ret_type, var);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1084 | |
1085 | /** |
1086 | * Gets an element from a struct |
1087 | */ |
class GetElementStmt : public Stmt {
 public:
  Stmt *src;
  // Path of the element within (possibly nested) struct |src|.
  std::vector<int> index;
  GetElementStmt(Stmt *src, const std::vector<int> &index)
      : src(src), index(index) {
    TI_STMT_REG_FIELDS;
  }

  TI_STMT_DEF_FIELDS(ret_type, src, index);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1100 | |
1101 | /** |
1102 | * Exit the kernel or function with a return value. |
1103 | */ |
1104 | class ReturnStmt : public Stmt { |
1105 | public: |
1106 | std::vector<Stmt *> values; |
1107 | |
1108 | explicit ReturnStmt(const std::vector<Stmt *> &values) : values(values) { |
1109 | TI_STMT_REG_FIELDS; |
1110 | } |
1111 | |
1112 | explicit ReturnStmt(Stmt *value) : values({value}) { |
1113 | TI_STMT_REG_FIELDS; |
1114 | } |
1115 | |
1116 | std::vector<DataType> element_types() { |
1117 | std::vector<DataType> ele_types; |
1118 | for (auto &x : values) { |
1119 | ele_types.push_back(x->element_type()); |
1120 | } |
1121 | return ele_types; |
1122 | } |
1123 | |
1124 | std::string values_raw_names() { |
1125 | std::string names; |
1126 | for (auto &x : values) { |
1127 | names += x->raw_name() + ", " ; |
1128 | } |
1129 | names.pop_back(); |
1130 | names.pop_back(); |
1131 | return names; |
1132 | } |
1133 | |
1134 | TI_STMT_DEF_FIELDS(values); |
1135 | TI_DEFINE_ACCEPT_AND_CLONE |
1136 | }; |
1137 | |
1138 | /** |
1139 | * A serial while-true loop. |mask| is to support vectorization. |
1140 | */ |
class WhileStmt : public Stmt {
 public:
  // Loop mask to support vectorization (see class comment).
  Stmt *mask;
  std::unique_ptr<Block> body;

  explicit WhileStmt(std::unique_ptr<Block> &&body);

  bool is_container_statement() const override {
    return true;
  }

  std::unique_ptr<Stmt> clone() const override;

  TI_STMT_DEF_FIELDS(mask);
  TI_DEFINE_ACCEPT
};
1157 | |
1158 | // TODO: remove this (replace with input + ConstStmt(offset)) |
// Computes |input| + |offset| with a compile-time constant offset.
class IntegerOffsetStmt : public Stmt {
 public:
  Stmt *input;
  int64 offset;

  IntegerOffsetStmt(Stmt *input, int64 offset) : input(input), offset(offset) {
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, input, offset);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1175 | |
1176 | /** |
1177 | * All indices of an address fused together. |
1178 | */ |
class LinearizeStmt : public Stmt {
 public:
  // Per-dimension index inputs and their corresponding strides; the two
  // vectors must have equal length (asserted in the constructor).
  std::vector<Stmt *> inputs;
  std::vector<int> strides;

  LinearizeStmt(const std::vector<Stmt *> &inputs,
                const std::vector<int> &strides)
      : inputs(inputs), strides(strides) {
    TI_ASSERT(inputs.size() == strides.size());
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, inputs, strides);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1198 | |
1199 | /** |
1200 | * The SNode root. |
1201 | */ |
1202 | class GetRootStmt : public Stmt { |
1203 | public: |
1204 | explicit GetRootStmt(SNode *root = nullptr) : root_(root) { |
1205 | if (this->root_ != nullptr) { |
1206 | while (this->root_->parent) { |
1207 | this->root_ = this->root_->parent; |
1208 | } |
1209 | } |
1210 | TI_STMT_REG_FIELDS; |
1211 | } |
1212 | |
1213 | bool has_global_side_effect() const override { |
1214 | return false; |
1215 | } |
1216 | |
1217 | TI_STMT_DEF_FIELDS(ret_type, root_); |
1218 | TI_DEFINE_ACCEPT_AND_CLONE |
1219 | |
1220 | SNode *root() { |
1221 | return root_; |
1222 | } |
1223 | |
1224 | const SNode *root() const { |
1225 | return root_; |
1226 | } |
1227 | |
1228 | private: |
1229 | SNode *root_; |
1230 | }; |
1231 | |
1232 | /** |
1233 | * Lookup a component of a SNode. |
1234 | */ |
class SNodeLookupStmt : public Stmt {
 public:
  SNode *snode;
  Stmt *input_snode;  // pointer to the parent container
  Stmt *input_index;  // linearized index within |snode|
  // Whether to activate the cell on lookup (writes) rather than merely
  // reading it.
  bool activate;

  SNodeLookupStmt(SNode *snode,
                  Stmt *input_snode,
                  Stmt *input_index,
                  bool activate)
      : snode(snode),
        input_snode(input_snode),
        input_index(input_index),
        activate(activate) {
    TI_STMT_REG_FIELDS;
  }

  // Activating lookups mutate sparse data structures.
  bool has_global_side_effect() const override {
    return activate;
  }

  bool common_statement_eliminable() const override {
    return true;
  }

  TI_STMT_DEF_FIELDS(ret_type, snode, input_snode, input_index, activate);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1264 | |
1265 | /** |
1266 | * Get a child of a SNode on the hierarchical SNode tree. |
1267 | */ |
class GetChStmt : public Stmt {
 public:
  Stmt *input_ptr;
  // Parent SNode and the child (|chid|-th) SNode being accessed.
  SNode *input_snode, *output_snode;
  int chid;
  bool is_bit_vectorized;

  GetChStmt(Stmt *input_ptr, int chid, bool is_bit_vectorized = false);
  GetChStmt(Stmt *input_ptr,
            SNode *snode,
            int chid,
            bool is_bit_vectorized = false);

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type,
                     input_ptr,
                     input_snode,
                     output_snode,
                     chid,
                     is_bit_vectorized);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1293 | |
1294 | /** |
1295 | * The statement corresponding to an offloaded task. |
1296 | */ |
class OffloadedStmt : public Stmt {
 public:
  using TaskType = OffloadedTaskType;

  TaskType task_type;
  // The arch this task targets.
  Arch device;
  SNode *snode{nullptr};
  // Range-for bounds: when |const_begin|/|const_end| are set the bound is
  // the constant |begin_value|/|end_value|; otherwise presumably it is read
  // from |begin_offset|/|end_offset| — confirm against the offload pass.
  std::size_t begin_offset{0};
  std::size_t end_offset{0};
  bool const_begin{false};
  bool const_end{false};
  int32 begin_value{0};
  int32 end_value{0};
  int grid_dim{1};
  int block_dim{1};
  bool reversed{false};
  bool is_bit_vectorized{false};
  int num_cpu_threads{1};
  Stmt *end_stmt{nullptr};
  std::string range_hint = "";

  // Mesh-for tasks only.
  mesh::Mesh *mesh{nullptr};
  mesh::MeshElementType major_from_type;
  std::unordered_set<mesh::MeshElementType> major_to_types;
  std::unordered_set<mesh::MeshRelationType> minor_relation_types;

  std::unordered_map<mesh::MeshElementType, Stmt *>
      owned_offset_local;  // |owned_offset[idx]|
  std::unordered_map<mesh::MeshElementType, Stmt *>
      total_offset_local;  // |total_offset[idx]|
  std::unordered_map<mesh::MeshElementType, Stmt *>
      owned_num_local;  // |owned_offset[idx+1] - owned_offset[idx]|
  std::unordered_map<mesh::MeshElementType, Stmt *>
      total_num_local;  // |total_offset[idx+1] - total_offset[idx]|

  std::vector<int> index_offsets;

  // Optional sub-blocks surrounding |body| (thread-local and block-local
  // storage prologues/epilogues).
  std::unique_ptr<Block> tls_prologue;
  std::unique_ptr<Block> mesh_prologue;  // mesh-for only block
  std::unique_ptr<Block> bls_prologue;
  std::unique_ptr<Block> body;
  std::unique_ptr<Block> bls_epilogue;
  std::unique_ptr<Block> tls_epilogue;
  std::size_t tls_size{1};  // avoid allocating dynamic memory with 0 byte
  std::size_t bls_size{0};
  MemoryAccessOptions mem_access_opt;

  OffloadedStmt(TaskType task_type, Arch arch);

  std::string task_name() const;

  static std::string task_type_name(TaskType tt);

  // listgen / gc / gc_rc tasks carry no user body.
  bool has_body() const {
    return task_type != TaskType::listgen && task_type != TaskType::gc &&
           task_type != TaskType::gc_rc;
  }

  bool is_container_statement() const override {
    return has_body();
  }

  std::unique_ptr<Stmt> clone() const override;

  // Let |visitor| visit every sub-block (prologues, body, epilogues),
  // optionally skipping |mesh_prologue|.
  void all_blocks_accept(IRVisitor *visitor, bool skip_mesh_prologue = false);

  TI_STMT_DEF_FIELDS(ret_type /*inherited from Stmt*/,
                     task_type,
                     device,
                     snode,
                     begin_offset,
                     end_offset,
                     const_begin,
                     const_end,
                     begin_value,
                     end_value,
                     grid_dim,
                     block_dim,
                     reversed,
                     num_cpu_threads,
                     index_offsets,
                     mem_access_opt);
  TI_DEFINE_ACCEPT
};
1381 | |
1382 | /** |
1383 | * The |index|-th index of the |loop|. |
1384 | */ |
1385 | class LoopIndexStmt : public Stmt { |
1386 | public: |
1387 | Stmt *loop; |
1388 | int index; |
1389 | |
1390 | LoopIndexStmt(Stmt *loop, int index) : loop(loop), index(index) { |
1391 | TI_STMT_REG_FIELDS; |
1392 | } |
1393 | |
1394 | bool is_mesh_index() const { |
1395 | if (auto offload = loop->cast<OffloadedStmt>()) { |
1396 | return offload->task_type == OffloadedTaskType::mesh_for; |
1397 | } else if (loop->cast<MeshForStmt>()) { |
1398 | return true; |
1399 | } else { |
1400 | return false; |
1401 | } |
1402 | } |
1403 | |
1404 | mesh::MeshElementType mesh_index_type() const { |
1405 | TI_ASSERT(is_mesh_index()); |
1406 | if (auto offload = loop->cast<OffloadedStmt>()) { |
1407 | return offload->major_from_type; |
1408 | } else if (auto mesh_for = loop->cast<MeshForStmt>()) { |
1409 | return mesh_for->major_from_type; |
1410 | } else { |
1411 | TI_NOT_IMPLEMENTED; |
1412 | } |
1413 | } |
1414 | |
1415 | bool has_global_side_effect() const override { |
1416 | return false; |
1417 | } |
1418 | |
1419 | TI_STMT_DEF_FIELDS(ret_type, loop, index); |
1420 | TI_DEFINE_ACCEPT_AND_CLONE |
1421 | }; |
1422 | |
1423 | /** |
1424 | * thread index within a CUDA block |
1425 | * TODO: Remove this. Have a better way for retrieving thread index. |
1426 | */ |
class LoopLinearIndexStmt : public Stmt {
 public:
  // The loop whose thread index is queried.
  Stmt *loop;

  explicit LoopLinearIndexStmt(Stmt *loop) : loop(loop) {
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, loop);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1442 | |
1443 | /** |
1444 | * global thread index, i.e. thread_idx() + block_idx() * block_dim() |
1445 | */ |
class GlobalThreadIndexStmt : public Stmt {
 public:
  explicit GlobalThreadIndexStmt() {
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1459 | |
1460 | /** |
1461 | * The lowest |index|-th index of the |loop| among the iterations iterated by |
1462 | * the block. |
1463 | */ |
class BlockCornerIndexStmt : public Stmt {
 public:
  Stmt *loop;
  // Which loop dimension the corner index is taken from.
  int index;

  BlockCornerIndexStmt(Stmt *loop, int index) : loop(loop), index(index) {
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, loop, index);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1480 | |
1481 | /** |
1482 | * A global temporary variable, located at |offset| in the global temporary |
1483 | * buffer. |
1484 | */ |
class GlobalTemporaryStmt : public Stmt {
 public:
  // Byte offset into the global temporary buffer.
  std::size_t offset;

  GlobalTemporaryStmt(std::size_t offset, const DataType &ret_type)
      : offset(offset) {
    this->ret_type = ret_type;
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, offset);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1502 | |
1503 | /** |
1504 | * A thread-local pointer, located at |offset| in the thread-local storage. |
1505 | */ |
class ThreadLocalPtrStmt : public Stmt {
 public:
  // Byte offset into the thread-local storage buffer.
  std::size_t offset;

  ThreadLocalPtrStmt(std::size_t offset, const DataType &ret_type)
      : offset(offset) {
    this->ret_type = ret_type;
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, offset);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1523 | |
1524 | /** |
1525 | * A block-local pointer, located at |offset| in the block-local storage. |
1526 | */ |
class BlockLocalPtrStmt : public Stmt {
 public:
  // Runtime-computed offset into block-local storage (unlike
  // ThreadLocalPtrStmt, this offset is a Stmt, not a constant).
  Stmt *offset;

  BlockLocalPtrStmt(Stmt *offset, const DataType &ret_type) : offset(offset) {
    this->ret_type = ret_type;
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, offset);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1543 | |
1544 | /** |
1545 | * The statement corresponding to a clear-list task. |
1546 | */ |
class ClearListStmt : public Stmt {
 public:
  explicit ClearListStmt(SNode *snode);

  // The SNode whose element list is cleared.
  SNode *snode;

  TI_STMT_DEF_FIELDS(ret_type, snode);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1556 | |
// Returns true iff the offloaded task |stmt| consists of a single
// ClearListStmt.
bool is_clear_list_task(const OffloadedStmt *stmt);
1559 | |
// A call to a runtime-internal function identified by name.
class InternalFuncStmt : public Stmt {
 public:
  std::string func_name;
  std::vector<Stmt *> args;
  // Whether a RuntimeContext pointer is passed as the implicit first argument.
  bool with_runtime_context;

  explicit InternalFuncStmt(const std::string &func_name,
                            const std::vector<Stmt *> &args,
                            Type *ret_type = nullptr,
                            bool with_runtime_context = true)
      : func_name(func_name),
        args(args),
        with_runtime_context(with_runtime_context) {
    // Default return type is i32 when none is specified.
    if (ret_type == nullptr) {
      this->ret_type = PrimitiveType::i32;
    } else {
      this->ret_type = ret_type;
    }
    TI_STMT_REG_FIELDS;
  }

  TI_STMT_DEF_FIELDS(ret_type, func_name, args, with_runtime_context);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1584 | |
1585 | class Texture; |
1586 | |
// A pointer to a texture, loaded from a kernel argument.
class TexturePtrStmt : public Stmt {
 public:
  Stmt *arg_load_stmt{nullptr};
  int dimensions{2};
  // True for storage (read/write) textures, false for sampled textures.
  bool is_storage{false};

  // Optional, for storage textures
  int num_channels{0};
  DataType channel_format{PrimitiveType::f32};
  int lod{0};

  explicit TexturePtrStmt(Stmt *stmt,
                          int dimensions,
                          bool is_storage,
                          int num_channels,
                          DataType channel_format,
                          int lod)
      : arg_load_stmt(stmt),
        dimensions(dimensions),
        is_storage(is_storage),
        num_channels(num_channels),
        channel_format(channel_format),
        lod(lod) {
    TI_STMT_REG_FIELDS;
  }

  // Convenience constructor for sampled (non-storage) textures.
  explicit TexturePtrStmt(Stmt *stmt, int dimensions)
      : arg_load_stmt(stmt), dimensions(dimensions), is_storage(false) {
    TI_STMT_REG_FIELDS;
  }

  TI_STMT_DEF_FIELDS(arg_load_stmt,
                     dimensions,
                     is_storage,
                     num_channels,
                     channel_format,
                     lod);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1626 | |
// A texture operation (sample / load / store / ...) on |texture_ptr|.
class TextureOpStmt : public Stmt {
 public:
  TextureOpType op;
  Stmt *texture_ptr;
  std::vector<Stmt *> args;

  explicit TextureOpStmt(TextureOpType op,
                         Stmt *texture_ptr,
                         const std::vector<Stmt *> &args)
      : op(op), texture_ptr(texture_ptr), args(args) {
    TI_STMT_REG_FIELDS;
  }

  // NOTE(review): this override is intentionally disabled; stores are kept
  // alive via common_statement_eliminable() below instead — confirm.
  /*
  bool has_global_side_effect() const override {
    return op == TextureOpType::kStore;
  }
  */

  bool common_statement_eliminable() const override {
    return op != TextureOpType::kStore;
  }

  TI_STMT_DEF_FIELDS(op, texture_ptr, args);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1653 | |
1654 | /** |
1655 | * A local AD-stack. |
1656 | */ |
class AdStackAllocaStmt : public Stmt {
 public:
  DataType dt;
  std::size_t max_size{0};  // 0 = adaptive

  AdStackAllocaStmt(const DataType &dt, std::size_t max_size)
      : dt(dt), max_size(max_size) {
    TI_STMT_REG_FIELDS;
  }

  // Size of one primal (or adjoint) element.
  // NOTE(review): computed from |ret_type|, not |dt| — presumably ret_type
  // is derived from dt during type checking; confirm.
  std::size_t element_size_in_bytes() const {
    return data_type_size(ret_type);
  }

  // One stack entry holds a primal value and its adjoint.
  std::size_t entry_size_in_bytes() const {
    return element_size_in_bytes() * 2;
  }

  // Total footprint: an int32 top-of-stack counter plus |max_size| entries.
  std::size_t size_in_bytes() const {
    return sizeof(int32) + entry_size_in_bytes() * max_size;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  bool common_statement_eliminable() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, dt, max_size);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1690 | |
1691 | /** |
1692 | * Load the top primal value of an AD-stack. |
1693 | */ |
class AdStackLoadTopStmt : public Stmt, public ir_traits::Load {
 public:
  // Must be an AdStackAllocaStmt (asserted in the constructor).
  Stmt *stack;

  explicit AdStackLoadTopStmt(Stmt *stack) {
    TI_ASSERT(stack->is<AdStackAllocaStmt>());
    this->stack = stack;
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  bool common_statement_eliminable() const override {
    return false;
  }

  // IR Trait: Load
  stmt_refs get_load_pointers() const override {
    return stack;
  }

  TI_STMT_DEF_FIELDS(ret_type, stack);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1720 | |
1721 | /** |
1722 | * Load the top adjoint value of an AD-stack. |
1723 | */ |
class AdStackLoadTopAdjStmt : public Stmt, public ir_traits::Load {
 public:
  // Must be an AdStackAllocaStmt (asserted in the constructor).
  Stmt *stack;

  explicit AdStackLoadTopAdjStmt(Stmt *stack) {
    TI_ASSERT(stack->is<AdStackAllocaStmt>());
    this->stack = stack;
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  bool common_statement_eliminable() const override {
    return false;
  }

  // IR Trait: Load
  stmt_refs get_load_pointers() const override {
    return stack;
  }

  TI_STMT_DEF_FIELDS(ret_type, stack);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1750 | |
1751 | /** |
1752 | * Pop the top primal and adjoint values in the AD-stack. |
1753 | */ |
class AdStackPopStmt : public Stmt, public ir_traits::Load {
 public:
  // Must be an AdStackAllocaStmt (asserted in the constructor).
  Stmt *stack;

  explicit AdStackPopStmt(Stmt *stack) {
    TI_ASSERT(stack->is<AdStackAllocaStmt>());
    this->stack = stack;
    TI_STMT_REG_FIELDS;
  }

  // IR Trait: Load
  stmt_refs get_load_pointers() const override {
    // This is to make dead store elimination not eliminate consequent pops.
    return stack;
  }

  // Mark has_global_side_effect == true to prevent being moved out of an if
  // clause in the simplify pass for now.

  TI_STMT_DEF_FIELDS(ret_type, stack);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1776 | |
1777 | /** |
1778 | * Push a primal value to the AD-stack, and set the corresponding adjoint |
1779 | * value to 0. |
1780 | */ |
class AdStackPushStmt : public Stmt, public ir_traits::Load {
 public:
  // Must be an AdStackAllocaStmt (asserted in the constructor).
  Stmt *stack;
  // The primal value being pushed.
  Stmt *v;

  AdStackPushStmt(Stmt *stack, Stmt *v) {
    TI_ASSERT(stack->is<AdStackAllocaStmt>());
    this->stack = stack;
    this->v = v;
    TI_STMT_REG_FIELDS;
  }

  // IR Trait: Load
  stmt_refs get_load_pointers() const override {
    // This is to make dead store elimination not eliminate consequent pushes.
    return stack;
  }

  // Mark has_global_side_effect == true to prevent being moved out of an if
  // clause in the simplify pass for now.

  TI_STMT_DEF_FIELDS(ret_type, stack, v);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1805 | |
1806 | /** |
1807 | * Accumulate |v| to the top adjoint value of the AD-stack. |
1808 | * This statement loads and stores the adjoint data. |
1809 | */ |
class AdStackAccAdjointStmt : public Stmt, public ir_traits::Load {
 public:
  // Must be an AdStackAllocaStmt (asserted in the constructor).
  Stmt *stack;
  // The value accumulated into the top adjoint slot.
  Stmt *v;

  AdStackAccAdjointStmt(Stmt *stack, Stmt *v) {
    TI_ASSERT(stack->is<AdStackAllocaStmt>());
    this->stack = stack;
    this->v = v;
    TI_STMT_REG_FIELDS;
  }

  // IR Trait: Load
  stmt_refs get_load_pointers() const override {
    return stack;
  }

  // Mark has_global_side_effect == true to prevent being moved out of an if
  // clause in the simplify pass for now.

  TI_STMT_DEF_FIELDS(ret_type, stack, v);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1833 | |
1834 | /** |
1835 | * A global store to one or more children of a bit struct. |
1836 | */ |
class BitStructStoreStmt : public Stmt {
 public:
  Stmt *ptr;
  // Child ids within the bit struct and the values stored into them;
  // the two vectors are index-aligned (asserted in the constructor).
  std::vector<int> ch_ids;
  std::vector<Stmt *> values;
  bool is_atomic;

  BitStructStoreStmt(Stmt *ptr,
                     const std::vector<int> &ch_ids,
                     const std::vector<Stmt *> &values)
      : ptr(ptr), ch_ids(ch_ids), values(values), is_atomic(true) {
    TI_ASSERT(ch_ids.size() == values.size());
    TI_STMT_REG_FIELDS;
  }

  // The BitStructType pointed to by |ptr|.
  BitStructType *get_bit_struct() const;

  bool common_statement_eliminable() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, ptr, ch_ids, values, is_atomic);
  TI_DEFINE_ACCEPT_AND_CLONE;
};
1861 | |
1862 | // Mesh related. |
1863 | |
1864 | /** |
1865 | * The relation access, mesh_idx -> to_type[neighbor_idx] |
1866 | * If neibhor_idex has no value, it returns the number of neighbors (length of |
1867 | * relation) of a mesh idx |
1868 | */ |
1869 | class MeshRelationAccessStmt : public Stmt { |
1870 | public: |
1871 | mesh::Mesh *mesh; |
1872 | Stmt *mesh_idx; |
1873 | mesh::MeshElementType to_type; |
1874 | Stmt *neighbor_idx; |
1875 | |
1876 | MeshRelationAccessStmt(mesh::Mesh *mesh, |
1877 | Stmt *mesh_idx, |
1878 | mesh::MeshElementType to_type, |
1879 | Stmt *neighbor_idx) |
1880 | : mesh(mesh), |
1881 | mesh_idx(mesh_idx), |
1882 | to_type(to_type), |
1883 | neighbor_idx(neighbor_idx) { |
1884 | this->ret_type = PrimitiveType::u16; |
1885 | TI_STMT_REG_FIELDS; |
1886 | } |
1887 | |
1888 | MeshRelationAccessStmt(mesh::Mesh *mesh, |
1889 | Stmt *mesh_idx, |
1890 | mesh::MeshElementType to_type) |
1891 | : mesh(mesh), |
1892 | mesh_idx(mesh_idx), |
1893 | to_type(to_type), |
1894 | neighbor_idx(nullptr) { |
1895 | this->ret_type = PrimitiveType::u16; |
1896 | TI_STMT_REG_FIELDS; |
1897 | } |
1898 | |
1899 | bool is_size() const { |
1900 | return neighbor_idx == nullptr; |
1901 | } |
1902 | |
1903 | bool has_global_side_effect() const override { |
1904 | return false; |
1905 | } |
1906 | |
1907 | mesh::MeshElementType from_type() const { |
1908 | if (auto idx = mesh_idx->cast<LoopIndexStmt>()) { |
1909 | TI_ASSERT(idx->is_mesh_index()); |
1910 | return idx->mesh_index_type(); |
1911 | } else if (auto idx = mesh_idx->cast<MeshRelationAccessStmt>()) { |
1912 | TI_ASSERT(!idx->is_size()); |
1913 | return idx->to_type; |
1914 | } else { |
1915 | TI_NOT_IMPLEMENTED; |
1916 | } |
1917 | } |
1918 | |
1919 | TI_STMT_DEF_FIELDS(ret_type, mesh, mesh_idx, to_type, neighbor_idx); |
1920 | TI_DEFINE_ACCEPT_AND_CLONE |
1921 | }; |
1922 | |
1923 | /** |
1924 | * Convert a mesh index to another index space |
1925 | */ |
class MeshIndexConversionStmt : public Stmt {
 public:
  mesh::Mesh *mesh;
  // Element type of the input index |idx|.
  mesh::MeshElementType idx_type;
  Stmt *idx;

  // The kind of index-space conversion to perform.
  mesh::ConvType conv_type;

  MeshIndexConversionStmt(mesh::Mesh *mesh,
                          mesh::MeshElementType idx_type,
                          Stmt *idx,
                          mesh::ConvType conv_type)
      : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) {
    this->ret_type = PrimitiveType::i32;
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, mesh, idx_type, idx, conv_type);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1950 | |
1951 | /** |
1952 | * The patch index of the |mesh_loop|. |
1953 | */ |
class MeshPatchIndexStmt : public Stmt {
 public:
  MeshPatchIndexStmt() {
    this->ret_type = PrimitiveType::i32;
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1968 | |
1969 | /** |
1970 | * Initialization of a local matrix |
1971 | */ |
class MatrixInitStmt : public Stmt {
 public:
  // Element values of the matrix, flattened.
  std::vector<Stmt *> values;

  explicit MatrixInitStmt(const std::vector<Stmt *> &values) : values(values) {
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, values);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1987 | |
1988 | } // namespace taichi::lang |
1989 | |