1#pragma once
2
#include "taichi/ir/ir.h"
#include "taichi/ir/mesh.h"
#include "taichi/ir/offloaded_task_type.h"
#include "taichi/ir/stmt_op_types.h"
#include "taichi/rhi/arch.h"

#include <optional>
#include <utility>
10
11namespace taichi::lang {
12
13class Function;
14
15/**
16 * Allocate a local variable with initial value 0.
17 */
18class AllocaStmt : public Stmt, public ir_traits::Store {
19 public:
20 explicit AllocaStmt(DataType type) : is_shared(false) {
21 ret_type = type;
22 TI_STMT_REG_FIELDS;
23 }
24
25 AllocaStmt(const std::vector<int> &shape,
26 DataType type,
27 bool is_shared = false)
28 : is_shared(is_shared) {
29 ret_type = TypeFactory::create_tensor_type(shape, type);
30 TI_STMT_REG_FIELDS;
31 }
32
33 bool has_global_side_effect() const override {
34 return false;
35 }
36
37 bool common_statement_eliminable() const override {
38 return false;
39 }
40
41 // IR Trait: Store
42 stmt_refs get_store_destination() const override {
43 // The statement itself provides a data source (const [0]).
44 return ret_type->is<TensorType>() ? nullptr : (Stmt *)this;
45 }
46
47 Stmt *get_store_data() const override {
48 // For convenience, return store_stmt instead of the const [0] it actually
49 // stores.
50 return ret_type->is<TensorType>() ? nullptr : (Stmt *)this;
51 }
52
53 bool is_shared;
54 TI_STMT_DEF_FIELDS(ret_type, is_shared);
55 TI_DEFINE_ACCEPT_AND_CLONE
56};
57
58/**
59 * Updates mask, break if all bits of the mask are 0.
60 */
61class WhileControlStmt : public Stmt {
62 public:
63 Stmt *mask;
64 Stmt *cond;
65 WhileControlStmt(Stmt *mask, Stmt *cond) : mask(mask), cond(cond) {
66 TI_STMT_REG_FIELDS;
67 }
68
69 TI_STMT_DEF_FIELDS(mask, cond);
70 TI_DEFINE_ACCEPT_AND_CLONE;
71};
72
73/**
74 * Jump to the next loop iteration, i.e., `continue` in C++.
75 */
76class ContinueStmt : public Stmt {
77 public:
78 // This is the loop on which this continue stmt has effects. It can be either
79 // an offloaded task, or a for/while loop inside the kernel.
80 Stmt *scope;
81
82 ContinueStmt() : scope(nullptr) {
83 TI_STMT_REG_FIELDS;
84 }
85
86 // For top-level loops, since they are parallelized to multiple threads (on
87 // either CPU or GPU), `continue` becomes semantically equivalent to `return`.
88 //
89 // Caveat:
90 // We should wrap each backend's kernel body into a function (as LLVM does).
91 // The reason is that, each thread may handle more than one element,
92 // depending on the backend's implementation.
93 //
94 // For example, CUDA uses grid-stride loops, the snippet below illustrates
95 // the idea:
96 //
97 // __global__ foo_kernel(...) {
98 // for (int i = lower; i < upper; i += gridDim) {
99 // auto coord = compute_coords(i);
100 // // run_foo_kernel is produced by codegen
101 // run_foo_kernel(coord);
102 // }
103 // }
104 //
105 // If run_foo_kernel() is directly inlined within foo_kernel(), `return`
106 // could prematurely terminate the entire kernel.
107
108 TI_STMT_DEF_FIELDS(scope);
109 TI_DEFINE_ACCEPT_AND_CLONE;
110};
111
112/**
113 * A decoration statement. The decorated "operands" will keep this decoration.
114 */
115class DecorationStmt : public Stmt {
116 public:
117 enum class Decoration : uint32_t { kUnknown, kLoopUnique };
118
119 Stmt *operand;
120 std::vector<uint32_t> decoration;
121
122 DecorationStmt(Stmt *operand, const std::vector<uint32_t> &decoration);
123
124 bool same_operation(DecorationStmt *o) const {
125 return false;
126 }
127
128 bool is_cast() const {
129 return false;
130 }
131
132 bool has_global_side_effect() const override {
133 return false;
134 }
135
136 bool dead_instruction_eliminable() const override {
137 return false;
138 }
139
140 TI_STMT_DEF_FIELDS(operand, decoration);
141 TI_DEFINE_ACCEPT_AND_CLONE
142};
143
144/**
145 * A unary operation. The field |cast_type| is used only when is_cast() is true.
146 */
147class UnaryOpStmt : public Stmt {
148 public:
149 UnaryOpType op_type;
150 Stmt *operand;
151 DataType cast_type;
152
153 UnaryOpStmt(UnaryOpType op_type, Stmt *operand);
154
155 bool same_operation(UnaryOpStmt *o) const;
156 bool is_cast() const;
157
158 bool has_global_side_effect() const override {
159 return false;
160 }
161
162 TI_STMT_DEF_FIELDS(ret_type, op_type, operand, cast_type);
163 TI_DEFINE_ACCEPT_AND_CLONE
164};
165
166/**
167 * Load a kernel argument. The data type should be known when constructing this
168 * statement. |is_ptr| should be true iff the result can be used as a base
169 * pointer of an ExternalPtrStmt.
170 */
171class ArgLoadStmt : public Stmt {
172 public:
173 int arg_id;
174
175 /* TODO(zhanlue): more organized argument-type information
176
177 ArgLoadStmt is able to load everything passed into the kernel,
178 including but not limited to: scalar, matrix, snode_tree_types(WIP),
179 ndarray, ...
180
181 Therefore we need to add a field to indicate the type of the argument. For
182 now, only "is_ptr" and "field_dims" is needed.
183
184 */
185 bool is_ptr;
186
187 bool is_grad;
188
189 // field_dims of ndarray
190 int field_dims_ = 0;
191
192 ArgLoadStmt(int arg_id,
193 const DataType &dt,
194 bool is_ptr = false,
195 bool is_grad = false)
196 : arg_id(arg_id) {
197 this->ret_type = dt;
198 this->is_ptr = is_ptr;
199 this->is_grad = is_grad;
200 this->field_dims_ = -1; // -1 means uninitialized
201 TI_STMT_REG_FIELDS;
202 }
203
204 void set_extern_dims(int dims) {
205 this->field_dims_ = dims;
206 }
207
208 bool has_global_side_effect() const override {
209 return false;
210 }
211
212 TI_STMT_DEF_FIELDS(ret_type, arg_id, is_ptr);
213 TI_DEFINE_ACCEPT_AND_CLONE
214};
215
216/**
217 * A random value. For i32, u32, i64, and u64, the result is randomly sampled
218 * from all possible values with equal probability. For f32 and f64 data types,
219 * the result is uniformly sampled in the interval [0, 1).
220 * When Taichi runtime initializes, each CUDA thread / CPU thread gets a
221 * different (but deterministic as long as the thread id doesn't change)
222 * random seed. Each invocation of a RandStmt compiles to a call of a
223 * deterministic PRNG to generate a random value in the backend.
224 */
225class RandStmt : public Stmt {
226 public:
227 explicit RandStmt(const DataType &dt) {
228 ret_type = dt;
229 TI_STMT_REG_FIELDS;
230 }
231
232 bool has_global_side_effect() const override {
233 return false;
234 }
235
236 bool common_statement_eliminable() const override {
237 return false;
238 }
239
240 TI_STMT_DEF_FIELDS(ret_type);
241 TI_DEFINE_ACCEPT_AND_CLONE
242};
243
244/**
245 * A binary operation.
246 */
247class BinaryOpStmt : public Stmt {
248 public:
249 BinaryOpType op_type;
250 Stmt *lhs, *rhs;
251 bool is_bit_vectorized; // TODO: remove this field
252
253 BinaryOpStmt(BinaryOpType op_type,
254 Stmt *lhs,
255 Stmt *rhs,
256 bool is_bit_vectorized = false)
257 : op_type(op_type),
258 lhs(lhs),
259 rhs(rhs),
260 is_bit_vectorized(is_bit_vectorized) {
261 TI_ASSERT(!lhs->is<AllocaStmt>());
262 TI_ASSERT(!rhs->is<AllocaStmt>());
263 TI_STMT_REG_FIELDS;
264 }
265
266 bool has_global_side_effect() const override {
267 return false;
268 }
269
270 TI_STMT_DEF_FIELDS(ret_type, op_type, lhs, rhs, is_bit_vectorized);
271 TI_DEFINE_ACCEPT_AND_CLONE
272};
273
274/**
275 * A ternary operation. Currently "select" (the ternary conditional operator,
276 * "?:" in C++) is the only supported ternary operation.
277 */
278class TernaryOpStmt : public Stmt {
279 public:
280 TernaryOpType op_type;
281 Stmt *op1, *op2, *op3;
282
283 TernaryOpStmt(TernaryOpType op_type, Stmt *op1, Stmt *op2, Stmt *op3)
284 : op_type(op_type), op1(op1), op2(op2), op3(op3) {
285 TI_ASSERT(!op1->is<AllocaStmt>());
286 TI_ASSERT(!op2->is<AllocaStmt>());
287 TI_ASSERT(!op3->is<AllocaStmt>());
288 TI_STMT_REG_FIELDS;
289 }
290
291 bool has_global_side_effect() const override {
292 return false;
293 }
294
295 TI_STMT_DEF_FIELDS(ret_type, op1, op2, op3);
296 TI_DEFINE_ACCEPT_AND_CLONE
297};
298
299/**
300 * An atomic operation.
301 */
302class AtomicOpStmt : public Stmt,
303 public ir_traits::Store,
304 public ir_traits::Load {
305 public:
306 AtomicOpType op_type;
307 Stmt *dest, *val;
308 bool is_reduction;
309
310 AtomicOpStmt(AtomicOpType op_type, Stmt *dest, Stmt *val)
311 : op_type(op_type), dest(dest), val(val), is_reduction(false) {
312 TI_STMT_REG_FIELDS;
313 }
314
315 static std::unique_ptr<AtomicOpStmt> make_for_reduction(AtomicOpType op_type,
316 Stmt *dest,
317 Stmt *val) {
318 auto stmt = std::make_unique<AtomicOpStmt>(op_type, dest, val);
319 stmt->is_reduction = true;
320 return stmt;
321 }
322
323 // IR Trait: Store
324 stmt_refs get_store_destination() const override {
325 return dest;
326 }
327
328 Stmt *get_store_data() const override {
329 return nullptr;
330 }
331
332 // IR Trait: Load
333 stmt_refs get_load_pointers() const override {
334 return dest;
335 }
336
337 TI_STMT_DEF_FIELDS(ret_type, op_type, dest, val);
338 TI_DEFINE_ACCEPT_AND_CLONE
339};
340
341/**
342 * An external pointer. |base_ptr| should be ArgLoadStmt with
343 * |is_ptr| == true.
344 */
345class ExternalPtrStmt : public Stmt {
346 public:
347 Stmt *base_ptr;
348 std::vector<Stmt *> indices;
349 std::vector<int> element_shape;
350 // AOS: element_dim < 0
351 // SOA: element_dim > 0
352 int element_dim;
353
354 ExternalPtrStmt(Stmt *base_ptr, const std::vector<Stmt *> &indices);
355
356 ExternalPtrStmt(Stmt *base_ptr,
357 const std::vector<Stmt *> &indices,
358 const std::vector<int> &element_shape,
359 int element_dim);
360
361 bool has_global_side_effect() const override {
362 return false;
363 }
364
365 TI_STMT_DEF_FIELDS(ret_type, base_ptr, indices);
366 TI_DEFINE_ACCEPT_AND_CLONE
367};
368
369/**
370 * A global pointer, currently only able to represent an address in a SNode.
371 * When |activate| is true, this statement activates the address it points to,
372 * so it has "global side effect" in this case.
373 * After the "lower_access" pass, all GlobalPtrStmts should be lowered into
374 * SNodeLookupStmts and GetChStmts, and should not appear in the final lowered
375 * IR.
376 */
377class GlobalPtrStmt : public Stmt {
378 public:
379 SNode *snode;
380 std::vector<Stmt *> indices;
381 bool activate;
382 bool is_cell_access;
383 bool is_bit_vectorized; // for bit_loop_vectorize pass
384
385 GlobalPtrStmt(SNode *snode,
386 const std::vector<Stmt *> &indices,
387 bool activate = true,
388 bool is_cell_access = false);
389
390 bool has_global_side_effect() const override {
391 return activate;
392 }
393
394 bool common_statement_eliminable() const override {
395 return true;
396 }
397
398 TI_STMT_DEF_FIELDS(ret_type, snode, indices, activate, is_bit_vectorized);
399 TI_DEFINE_ACCEPT_AND_CLONE
400};
401
402/**
403 * An "abstract" pointer for an element of a MatrixField, which logically
404 * contains a matrix of GlobalPtrStmts. Upon construction, only snodes, indices,
405 * dynamic_indexable, dynamic_index_stride and activate are initialized. After
406 * the lower_matrix_ptr pass, this stmt will either be eliminated (constant
407 * index) or have ptr_base initialized (dynamic index or whole-matrix access).
408 */
409class MatrixOfGlobalPtrStmt : public Stmt {
410 public:
411 std::vector<SNode *> snodes;
412 std::vector<Stmt *> indices;
413 Stmt *ptr_base{nullptr};
414 bool dynamic_indexable{false};
415 int dynamic_index_stride{0};
416 bool activate{true};
417
418 MatrixOfGlobalPtrStmt(const std::vector<SNode *> &snodes,
419 const std::vector<Stmt *> &indices,
420 bool dynamic_indexable,
421 int dynamic_index_stride,
422 DataType dt,
423 bool activate = true);
424
425 bool has_global_side_effect() const override {
426 return activate;
427 }
428
429 bool common_statement_eliminable() const override {
430 return true;
431 }
432
433 TI_STMT_DEF_FIELDS(ret_type,
434 snodes,
435 indices,
436 ptr_base,
437 dynamic_indexable,
438 dynamic_index_stride,
439 activate);
440 TI_DEFINE_ACCEPT_AND_CLONE
441};
442
443/**
444 * A matrix of MatrixPtrStmts. The purpose of this stmt is to handle matrix
445 * slice and vector swizzle. This stmt will be eliminated after the
446 * lower_matrix_ptr pass.
447 *
448 * TODO(yi/zhanlue): Keep scalarization pass alive for MatrixOfMatrixPtrStmt
449 * operations even with real_matrix_scalarize=False
450 */
451class MatrixOfMatrixPtrStmt : public Stmt {
452 public:
453 std::vector<Stmt *> stmts;
454
455 MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts, DataType dt);
456
457 TI_STMT_DEF_FIELDS(ret_type, stmts);
458 TI_DEFINE_ACCEPT_AND_CLONE
459};
460
461/**
462 * A pointer to an element of a matrix.
463 */
464class MatrixPtrStmt : public Stmt {
465 public:
466 Stmt *origin{nullptr};
467 Stmt *offset{nullptr};
468
469 MatrixPtrStmt(Stmt *, Stmt *, const std::string & = "");
470
471 /* TODO(zhanlue/yi): Unify semantics of offset in MatrixPtrStmt
472
473 There is a hack in MatrixPtrStmt in terms of the semantics of "offset",
474 where "offset" can be interpreted as "number of bytes" or "index" in
475 different upper-level code paths
476
477 Here we created this offset_used_as_index() function to help indentify
478 "offset"'s semantic, but in the end we should unify these two semantics.
479 */
480 bool offset_used_as_index() const {
481 if (origin->is<AllocaStmt>() || origin->is<GlobalTemporaryStmt>() ||
482 origin->is<ExternalPtrStmt>()) {
483 TI_ASSERT_INFO(origin->ret_type.ptr_removed()->is<TensorType>(),
484 "MatrixPtrStmt can only be used for TensorType.");
485 return true;
486 }
487 return false;
488 }
489
490 bool is_unlowered_global_ptr() const {
491 return origin->is<GlobalPtrStmt>();
492 }
493
494 bool has_global_side_effect() const override {
495 // After access lowered, activate info will be recorded in SNodeLookupStmt's
496 // activate for AOS sparse data structure. We don't support SOA sparse data
497 // structure for now.
498 return false;
499 }
500
501 TI_STMT_DEF_FIELDS(ret_type, origin, offset);
502 TI_DEFINE_ACCEPT_AND_CLONE
503};
504
505/**
506 * An operation to a SNode (not necessarily a leaf SNode).
507 */
508class SNodeOpStmt : public Stmt, public ir_traits::Store {
509 public:
510 SNodeOpType op_type;
511 SNode *snode;
512 Stmt *ptr;
513 Stmt *val;
514
515 SNodeOpStmt(SNodeOpType op_type,
516 SNode *snode,
517 Stmt *ptr,
518 Stmt *val = nullptr);
519
520 static bool activation_related(SNodeOpType op);
521
522 static bool need_activation(SNodeOpType op);
523
524 // IR Trait: Store
525 stmt_refs get_store_destination() const override {
526 if (op_type == SNodeOpType::allocate) {
527 return std::vector<Stmt *>{val, ptr};
528 } else {
529 return nullptr;
530 }
531 }
532
533 Stmt *get_store_data() const override {
534 return nullptr;
535 }
536
537 TI_STMT_DEF_FIELDS(ret_type, op_type, snode, ptr, val);
538 TI_DEFINE_ACCEPT_AND_CLONE
539};
540
// TODO: remove this
// (penguinliong) This Stmt is used for both ND-arrays and textures. This is
// subject to change in the future.
/**
 * The extent of external tensor argument |arg_id| along dimension |axis|.
 */
class ExternalTensorShapeAlongAxisStmt : public Stmt {
 public:
  int axis;
  int arg_id;

  ExternalTensorShapeAlongAxisStmt(int axis, int arg_id);

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, axis, arg_id);
  TI_DEFINE_ACCEPT_AND_CLONE
};
558
559/**
560 * An assertion.
561 * If |cond| is false, print the formatted |text| with |args|, and terminate
562 * the program.
563 */
564class AssertStmt : public Stmt {
565 public:
566 Stmt *cond;
567 std::string text;
568 std::vector<Stmt *> args;
569
570 AssertStmt(Stmt *cond,
571 const std::string &text,
572 const std::vector<Stmt *> &args)
573 : cond(cond), text(text), args(args) {
574 TI_ASSERT(cond);
575 TI_STMT_REG_FIELDS;
576 }
577
578 TI_STMT_DEF_FIELDS(cond, text, args);
579 TI_DEFINE_ACCEPT_AND_CLONE
580};
581
582/**
583 * Call an external (C++) function.
584 */
585class ExternalFuncCallStmt : public Stmt,
586 public ir_traits::Store,
587 public ir_traits::Load {
588 public:
589 enum Type { SHARED_OBJECT = 0, ASSEMBLY = 1, BITCODE = 2 };
590
591 Type type;
592 void *so_func; // SHARED_OBJECT
593 std::string asm_source; // ASM
594 std::string bc_filename; // BITCODE
595 std::string bc_funcname; // BITCODE
596 std::vector<Stmt *> arg_stmts;
597 std::vector<Stmt *> output_stmts; // BITCODE doesn't use this
598
599 ExternalFuncCallStmt(Type type,
600 void *so_func,
601 std::string asm_source,
602 std::string bc_filename,
603 std::string bc_funcname,
604 const std::vector<Stmt *> &arg_stmts,
605 const std::vector<Stmt *> &output_stmts)
606 : type(type),
607 so_func(so_func),
608 asm_source(asm_source),
609 bc_filename(bc_filename),
610 bc_funcname(bc_funcname),
611 arg_stmts(arg_stmts),
612 output_stmts(output_stmts) {
613 TI_STMT_REG_FIELDS;
614 }
615
616 // IR Trait: Store
617 stmt_refs get_store_destination() const override {
618 if (type == ExternalFuncCallStmt::BITCODE) {
619 return arg_stmts;
620 } else {
621 return output_stmts;
622 }
623 }
624
625 Stmt *get_store_data() const override {
626 return nullptr;
627 }
628
629 // IR Trait: Load
630 stmt_refs get_load_pointers() const override {
631 return arg_stmts;
632 }
633
634 TI_STMT_DEF_FIELDS(type,
635 so_func,
636 asm_source,
637 bc_filename,
638 bc_funcname,
639 arg_stmts,
640 output_stmts);
641 TI_DEFINE_ACCEPT_AND_CLONE
642};
643
644/**
645 * A hint to the Taichi compiler about the relation of the values of two
646 * statements.
647 * This statement simply returns the input statement at the backend, and hints
648 * the Taichi compiler that |base| + |low| <= |input| < |base| + |high|.
649 */
650class RangeAssumptionStmt : public Stmt {
651 public:
652 Stmt *input;
653 Stmt *base;
654 int low, high;
655
656 RangeAssumptionStmt(Stmt *input, Stmt *base, int low, int high)
657 : input(input), base(base), low(low), high(high) {
658 TI_STMT_REG_FIELDS;
659 }
660
661 bool has_global_side_effect() const override {
662 return false;
663 }
664
665 TI_STMT_DEF_FIELDS(ret_type, input, base, low, high);
666 TI_DEFINE_ACCEPT_AND_CLONE
667};
668
669/**
670 * A hint to the Taichi compiler that a statement has unique values among
671 * the top-level loop. This statement simply returns the input statement at
672 * the backend, and hints the Taichi compiler that this statement never
673 * evaluate to the same value across different iterations of the top-level
674 * loop. This statement's value set among all iterations of the top-level loop
675 * also covers all active indices of each SNodes with id in the |covers| field
676 * of this statement. Since this statement can only evaluate to one value,
677 * the SNodes with id in the |covers| field should have only one dimension.
678 */
679class LoopUniqueStmt : public Stmt {
680 public:
681 Stmt *input;
682 std::unordered_set<int> covers; // Stores SNode id
683 // std::unordered_set<> provides operator==, and StmtFieldManager will
684 // use that to check if two LoopUniqueStmts are the same.
685
686 LoopUniqueStmt(Stmt *input, const std::vector<SNode *> &covers);
687
688 bool has_global_side_effect() const override {
689 return false;
690 }
691
692 TI_STMT_DEF_FIELDS(ret_type, input, covers);
693 TI_DEFINE_ACCEPT_AND_CLONE
694};
695
696/**
697 * A load from a global address, including SNodes, external arrays, TLS, BLS,
698 * and global temporary variables.
699 */
700class GlobalLoadStmt : public Stmt, public ir_traits::Load {
701 public:
702 Stmt *src;
703
704 explicit GlobalLoadStmt(Stmt *src) : src(src) {
705 TI_STMT_REG_FIELDS;
706 }
707
708 bool has_global_side_effect() const override {
709 return false;
710 }
711
712 bool common_statement_eliminable() const override {
713 return false;
714 }
715
716 // IR Trait: Load
717 stmt_refs get_load_pointers() const override {
718 return src;
719 }
720
721 TI_STMT_DEF_FIELDS(ret_type, src);
722 TI_DEFINE_ACCEPT_AND_CLONE;
723};
724
725/**
726 * A store to a global address, including SNodes, external arrays, TLS, BLS,
727 * and global temporary variables.
728 */
729class GlobalStoreStmt : public Stmt, public ir_traits::Store {
730 public:
731 Stmt *dest;
732 Stmt *val;
733
734 GlobalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) {
735 TI_STMT_REG_FIELDS;
736 }
737
738 bool common_statement_eliminable() const override {
739 return false;
740 }
741
742 // IR Trait: Store
743 stmt_refs get_store_destination() const override {
744 return dest;
745 }
746
747 Stmt *get_store_data() const override {
748 return val;
749 }
750
751 TI_STMT_DEF_FIELDS(ret_type, dest, val);
752 TI_DEFINE_ACCEPT_AND_CLONE;
753};
754
755/**
756 * A load from a local variable, i.e., an "alloca".
757 */
758class LocalLoadStmt : public Stmt, public ir_traits::Load {
759 public:
760 Stmt *src;
761
762 explicit LocalLoadStmt(Stmt *src) : src(src) {
763 TI_STMT_REG_FIELDS;
764 }
765
766 bool has_global_side_effect() const override {
767 return false;
768 }
769
770 bool common_statement_eliminable() const override {
771 return false;
772 }
773
774 // IR Trait: Load
775 stmt_refs get_load_pointers() const override {
776 return src;
777 }
778
779 TI_STMT_DEF_FIELDS(ret_type, src);
780 TI_DEFINE_ACCEPT_AND_CLONE;
781};
782
783/**
784 * A store to a local variable, i.e., an "alloca".
785 */
786class LocalStoreStmt : public Stmt, public ir_traits::Store {
787 public:
788 Stmt *dest;
789 Stmt *val;
790
791 LocalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) {
792 TI_ASSERT(dest->is<AllocaStmt>() || dest->is<MatrixPtrStmt>() ||
793 dest->is<MatrixOfMatrixPtrStmt>());
794 TI_STMT_REG_FIELDS;
795 }
796
797 bool has_global_side_effect() const override {
798 return false;
799 }
800
801 bool dead_instruction_eliminable() const override {
802 return false;
803 }
804
805 bool common_statement_eliminable() const override {
806 return false;
807 }
808
809 // IR Trait: Store
810 stmt_refs get_store_destination() const override {
811 return dest;
812 }
813
814 Stmt *get_store_data() const override {
815 return val;
816 }
817
818 TI_STMT_DEF_FIELDS(ret_type, dest, val);
819 TI_DEFINE_ACCEPT_AND_CLONE;
820};
821
822/**
823 * Same as "if (cond) true_statements; else false_statements;" in C++.
824 * |true_mask| and |false_mask| are used to support vectorization.
825 */
826class IfStmt : public Stmt {
827 public:
828 Stmt *cond;
829 std::unique_ptr<Block> true_statements, false_statements;
830
831 explicit IfStmt(Stmt *cond);
832
833 // Use these setters to set Block::parent_stmt at the same time.
834 void set_true_statements(std::unique_ptr<Block> &&new_true_statements);
835 void set_false_statements(std::unique_ptr<Block> &&new_false_statements);
836
837 bool is_container_statement() const override {
838 return true;
839 }
840
841 std::unique_ptr<Stmt> clone() const override;
842
843 TI_STMT_DEF_FIELDS(cond);
844 TI_DEFINE_ACCEPT
845};
846
847/**
848 * Print the contents in this statement. Each entry in the contents can be
849 * either a statement or a string, and they are printed one by one, separated
850 * by a comma and a space.
851 */
852class PrintStmt : public Stmt {
853 public:
854 using EntryType = std::variant<Stmt *, std::string>;
855 std::vector<EntryType> contents;
856
857 explicit PrintStmt(const std::vector<EntryType> &contents_)
858 : contents(contents_) {
859 TI_STMT_REG_FIELDS;
860 }
861
862 template <typename... Args>
863 explicit PrintStmt(Stmt *t, Args &&...args)
864 : contents(make_entries(t, std::forward<Args>(args)...)) {
865 TI_STMT_REG_FIELDS;
866 }
867
868 template <typename... Args>
869 explicit PrintStmt(const std::string &str, Args &&...args)
870 : contents(make_entries(str, std::forward<Args>(args)...)) {
871 TI_STMT_REG_FIELDS;
872 }
873
874 TI_STMT_DEF_FIELDS(ret_type, contents);
875 TI_DEFINE_ACCEPT_AND_CLONE
876
877 private:
878 static void make_entries_helper(std::vector<PrintStmt::EntryType> &entries) {
879 }
880
881 template <typename T, typename... Args>
882 static void make_entries_helper(std::vector<PrintStmt::EntryType> &entries,
883 T &&t,
884 Args &&...values) {
885 entries.push_back(EntryType{t});
886 make_entries_helper(entries, std::forward<Args>(values)...);
887 }
888
889 template <typename... Args>
890 static std::vector<EntryType> make_entries(Args &&...values) {
891 std::vector<EntryType> ret;
892 make_entries_helper(ret, std::forward<Args>(values)...);
893 return ret;
894 }
895};
896
897/**
898 * A constant value.
899 */
900class ConstStmt : public Stmt {
901 public:
902 TypedConstant val;
903
904 explicit ConstStmt(const TypedConstant &val) : val(val) {
905 ret_type = val.dt;
906 TI_STMT_REG_FIELDS;
907 }
908
909 bool has_global_side_effect() const override {
910 return false;
911 }
912
913 TI_STMT_DEF_FIELDS(ret_type, val);
914 TI_DEFINE_ACCEPT_AND_CLONE
915};
916
917/**
918 * A general range for, similar to "for (i = begin; i < end; i++) body;" in C++.
919 * When |reversed| is true, the for loop is reversed, i.e.,
920 * "for (i = end - 1; i >= begin; i--) body;".
921 * When the statement is in the top level before offloading, it will be
922 * offloaded to a parallel for loop. Otherwise, it will be offloaded to a
923 * serial for loop.
924 */
925class RangeForStmt : public Stmt {
926 public:
927 Stmt *begin, *end;
928 std::unique_ptr<Block> body;
929 bool reversed;
930 bool is_bit_vectorized;
931 int num_cpu_threads;
932 int block_dim;
933 bool strictly_serialized;
934 std::string range_hint;
935
936 RangeForStmt(Stmt *begin,
937 Stmt *end,
938 std::unique_ptr<Block> &&body,
939 bool is_bit_vectorized,
940 int num_cpu_threads,
941 int block_dim,
942 bool strictly_serialized,
943 std::string range_hint = "");
944
945 bool is_container_statement() const override {
946 return true;
947 }
948
949 void reverse() {
950 reversed = !reversed;
951 }
952
953 std::unique_ptr<Stmt> clone() const override;
954
955 TI_STMT_DEF_FIELDS(begin,
956 end,
957 reversed,
958 is_bit_vectorized,
959 num_cpu_threads,
960 block_dim,
961 strictly_serialized);
962 TI_DEFINE_ACCEPT
963};
964
965/**
966 * A parallel for loop over a SNode, similar to "for i in snode: body"
967 * in Python. This statement must be at the top level before offloading.
968 */
969class StructForStmt : public Stmt {
970 public:
971 SNode *snode;
972 std::unique_ptr<Block> body;
973 std::unique_ptr<Block> block_initialization;
974 std::unique_ptr<Block> block_finalization;
975 std::vector<int> index_offsets;
976 bool is_bit_vectorized;
977 int num_cpu_threads;
978 int block_dim;
979 MemoryAccessOptions mem_access_opt;
980
981 StructForStmt(SNode *snode,
982 std::unique_ptr<Block> &&body,
983 bool is_bit_vectorized,
984 int num_cpu_threads,
985 int block_dim);
986
987 bool is_container_statement() const override {
988 return true;
989 }
990
991 std::unique_ptr<Stmt> clone() const override;
992
993 TI_STMT_DEF_FIELDS(snode,
994 index_offsets,
995 is_bit_vectorized,
996 num_cpu_threads,
997 block_dim,
998 mem_access_opt);
999 TI_DEFINE_ACCEPT
1000};
1001
1002/**
1003 * meshfor
1004 */
1005class MeshForStmt : public Stmt {
1006 public:
1007 mesh::Mesh *mesh;
1008 std::unique_ptr<Block> body;
1009 bool is_bit_vectorized;
1010 int num_cpu_threads;
1011 int block_dim;
1012 mesh::MeshElementType major_from_type;
1013 std::unordered_set<mesh::MeshElementType> major_to_types{};
1014 std::unordered_set<mesh::MeshRelationType> minor_relation_types{};
1015 MemoryAccessOptions mem_access_opt;
1016
1017 MeshForStmt(mesh::Mesh *mesh,
1018 mesh::MeshElementType element_type,
1019 std::unique_ptr<Block> &&body,
1020 bool is_bit_vectorized,
1021 int num_cpu_threads,
1022 int block_dim);
1023
1024 bool is_container_statement() const override {
1025 return true;
1026 }
1027
1028 std::unique_ptr<Stmt> clone() const override;
1029
1030 TI_STMT_DEF_FIELDS(mesh,
1031 is_bit_vectorized,
1032 num_cpu_threads,
1033 block_dim,
1034 major_from_type,
1035 major_to_types,
1036 minor_relation_types,
1037 mem_access_opt);
1038 TI_DEFINE_ACCEPT
1039};
1040
1041/**
1042 * Call an inline Taichi function.
1043 */
1044class FuncCallStmt : public Stmt {
1045 public:
1046 Function *func;
1047 std::vector<Stmt *> args;
1048 bool global_side_effect{true};
1049
1050 FuncCallStmt(Function *func, const std::vector<Stmt *> &args);
1051
1052 bool has_global_side_effect() const override {
1053 return global_side_effect;
1054 }
1055
1056 TI_STMT_DEF_FIELDS(ret_type, func, args);
1057 TI_DEFINE_ACCEPT_AND_CLONE
1058};
1059
1060/**
1061 * A reference to a variable.
1062 */
1063class ReferenceStmt : public Stmt, public ir_traits::Load {
1064 public:
1065 Stmt *var;
1066 bool global_side_effect{false};
1067
1068 explicit ReferenceStmt(Stmt *var) : var(var) {
1069 TI_STMT_REG_FIELDS;
1070 }
1071
1072 bool has_global_side_effect() const override {
1073 return global_side_effect;
1074 }
1075
1076 // IR Trait: Load
1077 stmt_refs get_load_pointers() const override {
1078 return var;
1079 }
1080
1081 TI_STMT_DEF_FIELDS(ret_type, var);
1082 TI_DEFINE_ACCEPT_AND_CLONE
1083};
1084
1085/**
1086 * Gets an element from a struct
1087 */
1088class GetElementStmt : public Stmt {
1089 public:
1090 Stmt *src;
1091 std::vector<int> index;
1092 GetElementStmt(Stmt *src, const std::vector<int> &index)
1093 : src(src), index(index) {
1094 TI_STMT_REG_FIELDS;
1095 }
1096
1097 TI_STMT_DEF_FIELDS(ret_type, src, index);
1098 TI_DEFINE_ACCEPT_AND_CLONE
1099};
1100
1101/**
1102 * Exit the kernel or function with a return value.
1103 */
1104class ReturnStmt : public Stmt {
1105 public:
1106 std::vector<Stmt *> values;
1107
1108 explicit ReturnStmt(const std::vector<Stmt *> &values) : values(values) {
1109 TI_STMT_REG_FIELDS;
1110 }
1111
1112 explicit ReturnStmt(Stmt *value) : values({value}) {
1113 TI_STMT_REG_FIELDS;
1114 }
1115
1116 std::vector<DataType> element_types() {
1117 std::vector<DataType> ele_types;
1118 for (auto &x : values) {
1119 ele_types.push_back(x->element_type());
1120 }
1121 return ele_types;
1122 }
1123
1124 std::string values_raw_names() {
1125 std::string names;
1126 for (auto &x : values) {
1127 names += x->raw_name() + ", ";
1128 }
1129 names.pop_back();
1130 names.pop_back();
1131 return names;
1132 }
1133
1134 TI_STMT_DEF_FIELDS(values);
1135 TI_DEFINE_ACCEPT_AND_CLONE
1136};
1137
1138/**
1139 * A serial while-true loop. |mask| is to support vectorization.
1140 */
1141class WhileStmt : public Stmt {
1142 public:
1143 Stmt *mask;
1144 std::unique_ptr<Block> body;
1145
1146 explicit WhileStmt(std::unique_ptr<Block> &&body);
1147
1148 bool is_container_statement() const override {
1149 return true;
1150 }
1151
1152 std::unique_ptr<Stmt> clone() const override;
1153
1154 TI_STMT_DEF_FIELDS(mask);
1155 TI_DEFINE_ACCEPT
1156};
1157
// TODO: remove this (replace with input + ConstStmt(offset))
/**
 * |input| shifted by the constant |offset|.
 */
class IntegerOffsetStmt : public Stmt {
 public:
  Stmt *input;
  int64 offset;

  IntegerOffsetStmt(Stmt *input, int64 offset) : input(input), offset(offset) {
    TI_STMT_REG_FIELDS;
  }

  bool has_global_side_effect() const override {
    return false;
  }

  TI_STMT_DEF_FIELDS(ret_type, input, offset);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1175
1176/**
1177 * All indices of an address fused together.
1178 */
1179class LinearizeStmt : public Stmt {
1180 public:
1181 std::vector<Stmt *> inputs;
1182 std::vector<int> strides;
1183
1184 LinearizeStmt(const std::vector<Stmt *> &inputs,
1185 const std::vector<int> &strides)
1186 : inputs(inputs), strides(strides) {
1187 TI_ASSERT(inputs.size() == strides.size());
1188 TI_STMT_REG_FIELDS;
1189 }
1190
1191 bool has_global_side_effect() const override {
1192 return false;
1193 }
1194
1195 TI_STMT_DEF_FIELDS(ret_type, inputs, strides);
1196 TI_DEFINE_ACCEPT_AND_CLONE
1197};
1198
1199/**
1200 * The SNode root.
1201 */
1202class GetRootStmt : public Stmt {
1203 public:
1204 explicit GetRootStmt(SNode *root = nullptr) : root_(root) {
1205 if (this->root_ != nullptr) {
1206 while (this->root_->parent) {
1207 this->root_ = this->root_->parent;
1208 }
1209 }
1210 TI_STMT_REG_FIELDS;
1211 }
1212
1213 bool has_global_side_effect() const override {
1214 return false;
1215 }
1216
1217 TI_STMT_DEF_FIELDS(ret_type, root_);
1218 TI_DEFINE_ACCEPT_AND_CLONE
1219
1220 SNode *root() {
1221 return root_;
1222 }
1223
1224 const SNode *root() const {
1225 return root_;
1226 }
1227
1228 private:
1229 SNode *root_;
1230};
1231
1232/**
1233 * Lookup a component of a SNode.
1234 */
1235class SNodeLookupStmt : public Stmt {
1236 public:
1237 SNode *snode;
1238 Stmt *input_snode;
1239 Stmt *input_index;
1240 bool activate;
1241
1242 SNodeLookupStmt(SNode *snode,
1243 Stmt *input_snode,
1244 Stmt *input_index,
1245 bool activate)
1246 : snode(snode),
1247 input_snode(input_snode),
1248 input_index(input_index),
1249 activate(activate) {
1250 TI_STMT_REG_FIELDS;
1251 }
1252
1253 bool has_global_side_effect() const override {
1254 return activate;
1255 }
1256
1257 bool common_statement_eliminable() const override {
1258 return true;
1259 }
1260
1261 TI_STMT_DEF_FIELDS(ret_type, snode, input_snode, input_index, activate);
1262 TI_DEFINE_ACCEPT_AND_CLONE
1263};
1264
1265/**
1266 * Get a child of a SNode on the hierarchical SNode tree.
1267 */
1268class GetChStmt : public Stmt {
1269 public:
1270 Stmt *input_ptr;
1271 SNode *input_snode, *output_snode;
1272 int chid;
1273 bool is_bit_vectorized;
1274
1275 GetChStmt(Stmt *input_ptr, int chid, bool is_bit_vectorized = false);
1276 GetChStmt(Stmt *input_ptr,
1277 SNode *snode,
1278 int chid,
1279 bool is_bit_vectorized = false);
1280
1281 bool has_global_side_effect() const override {
1282 return false;
1283 }
1284
1285 TI_STMT_DEF_FIELDS(ret_type,
1286 input_ptr,
1287 input_snode,
1288 output_snode,
1289 chid,
1290 is_bit_vectorized);
1291 TI_DEFINE_ACCEPT_AND_CLONE
1292};
1293
1294/**
1295 * The statement corresponding to an offloaded task.
1296 */
1297class OffloadedStmt : public Stmt {
1298 public:
1299 using TaskType = OffloadedTaskType;
1300
1301 TaskType task_type;
1302 Arch device;
1303 SNode *snode{nullptr};
1304 std::size_t begin_offset{0};
1305 std::size_t end_offset{0};
1306 bool const_begin{false};
1307 bool const_end{false};
1308 int32 begin_value{0};
1309 int32 end_value{0};
1310 int grid_dim{1};
1311 int block_dim{1};
1312 bool reversed{false};
1313 bool is_bit_vectorized{false};
1314 int num_cpu_threads{1};
1315 Stmt *end_stmt{nullptr};
1316 std::string range_hint = "";
1317
1318 mesh::Mesh *mesh{nullptr};
1319 mesh::MeshElementType major_from_type;
1320 std::unordered_set<mesh::MeshElementType> major_to_types;
1321 std::unordered_set<mesh::MeshRelationType> minor_relation_types;
1322
1323 std::unordered_map<mesh::MeshElementType, Stmt *>
1324 owned_offset_local; // |owned_offset[idx]|
1325 std::unordered_map<mesh::MeshElementType, Stmt *>
1326 total_offset_local; // |total_offset[idx]|
1327 std::unordered_map<mesh::MeshElementType, Stmt *>
1328 owned_num_local; // |owned_offset[idx+1] - owned_offset[idx]|
1329 std::unordered_map<mesh::MeshElementType, Stmt *>
1330 total_num_local; // |total_offset[idx+1] - total_offset[idx]|
1331
1332 std::vector<int> index_offsets;
1333
1334 std::unique_ptr<Block> tls_prologue;
1335 std::unique_ptr<Block> mesh_prologue; // mesh-for only block
1336 std::unique_ptr<Block> bls_prologue;
1337 std::unique_ptr<Block> body;
1338 std::unique_ptr<Block> bls_epilogue;
1339 std::unique_ptr<Block> tls_epilogue;
1340 std::size_t tls_size{1}; // avoid allocating dynamic memory with 0 byte
1341 std::size_t bls_size{0};
1342 MemoryAccessOptions mem_access_opt;
1343
1344 OffloadedStmt(TaskType task_type, Arch arch);
1345
1346 std::string task_name() const;
1347
1348 static std::string task_type_name(TaskType tt);
1349
1350 bool has_body() const {
1351 return task_type != TaskType::listgen && task_type != TaskType::gc &&
1352 task_type != TaskType::gc_rc;
1353 }
1354
1355 bool is_container_statement() const override {
1356 return has_body();
1357 }
1358
1359 std::unique_ptr<Stmt> clone() const override;
1360
1361 void all_blocks_accept(IRVisitor *visitor, bool skip_mesh_prologue = false);
1362
1363 TI_STMT_DEF_FIELDS(ret_type /*inherited from Stmt*/,
1364 task_type,
1365 device,
1366 snode,
1367 begin_offset,
1368 end_offset,
1369 const_begin,
1370 const_end,
1371 begin_value,
1372 end_value,
1373 grid_dim,
1374 block_dim,
1375 reversed,
1376 num_cpu_threads,
1377 index_offsets,
1378 mem_access_opt);
1379 TI_DEFINE_ACCEPT
1380};
1381
1382/**
1383 * The |index|-th index of the |loop|.
1384 */
1385class LoopIndexStmt : public Stmt {
1386 public:
1387 Stmt *loop;
1388 int index;
1389
1390 LoopIndexStmt(Stmt *loop, int index) : loop(loop), index(index) {
1391 TI_STMT_REG_FIELDS;
1392 }
1393
1394 bool is_mesh_index() const {
1395 if (auto offload = loop->cast<OffloadedStmt>()) {
1396 return offload->task_type == OffloadedTaskType::mesh_for;
1397 } else if (loop->cast<MeshForStmt>()) {
1398 return true;
1399 } else {
1400 return false;
1401 }
1402 }
1403
1404 mesh::MeshElementType mesh_index_type() const {
1405 TI_ASSERT(is_mesh_index());
1406 if (auto offload = loop->cast<OffloadedStmt>()) {
1407 return offload->major_from_type;
1408 } else if (auto mesh_for = loop->cast<MeshForStmt>()) {
1409 return mesh_for->major_from_type;
1410 } else {
1411 TI_NOT_IMPLEMENTED;
1412 }
1413 }
1414
1415 bool has_global_side_effect() const override {
1416 return false;
1417 }
1418
1419 TI_STMT_DEF_FIELDS(ret_type, loop, index);
1420 TI_DEFINE_ACCEPT_AND_CLONE
1421};
1422
1423/**
1424 * thread index within a CUDA block
1425 * TODO: Remove this. Have a better way for retrieving thread index.
1426 */
1427class LoopLinearIndexStmt : public Stmt {
1428 public:
1429 Stmt *loop;
1430
1431 explicit LoopLinearIndexStmt(Stmt *loop) : loop(loop) {
1432 TI_STMT_REG_FIELDS;
1433 }
1434
1435 bool has_global_side_effect() const override {
1436 return false;
1437 }
1438
1439 TI_STMT_DEF_FIELDS(ret_type, loop);
1440 TI_DEFINE_ACCEPT_AND_CLONE
1441};
1442
1443/**
1444 * global thread index, i.e. thread_idx() + block_idx() * block_dim()
1445 */
1446class GlobalThreadIndexStmt : public Stmt {
1447 public:
1448 explicit GlobalThreadIndexStmt() {
1449 TI_STMT_REG_FIELDS;
1450 }
1451
1452 bool has_global_side_effect() const override {
1453 return false;
1454 }
1455
1456 TI_STMT_DEF_FIELDS(ret_type);
1457 TI_DEFINE_ACCEPT_AND_CLONE
1458};
1459
1460/**
1461 * The lowest |index|-th index of the |loop| among the iterations iterated by
1462 * the block.
1463 */
1464class BlockCornerIndexStmt : public Stmt {
1465 public:
1466 Stmt *loop;
1467 int index;
1468
1469 BlockCornerIndexStmt(Stmt *loop, int index) : loop(loop), index(index) {
1470 TI_STMT_REG_FIELDS;
1471 }
1472
1473 bool has_global_side_effect() const override {
1474 return false;
1475 }
1476
1477 TI_STMT_DEF_FIELDS(ret_type, loop, index);
1478 TI_DEFINE_ACCEPT_AND_CLONE
1479};
1480
1481/**
1482 * A global temporary variable, located at |offset| in the global temporary
1483 * buffer.
1484 */
1485class GlobalTemporaryStmt : public Stmt {
1486 public:
1487 std::size_t offset;
1488
1489 GlobalTemporaryStmt(std::size_t offset, const DataType &ret_type)
1490 : offset(offset) {
1491 this->ret_type = ret_type;
1492 TI_STMT_REG_FIELDS;
1493 }
1494
1495 bool has_global_side_effect() const override {
1496 return false;
1497 }
1498
1499 TI_STMT_DEF_FIELDS(ret_type, offset);
1500 TI_DEFINE_ACCEPT_AND_CLONE
1501};
1502
1503/**
1504 * A thread-local pointer, located at |offset| in the thread-local storage.
1505 */
1506class ThreadLocalPtrStmt : public Stmt {
1507 public:
1508 std::size_t offset;
1509
1510 ThreadLocalPtrStmt(std::size_t offset, const DataType &ret_type)
1511 : offset(offset) {
1512 this->ret_type = ret_type;
1513 TI_STMT_REG_FIELDS;
1514 }
1515
1516 bool has_global_side_effect() const override {
1517 return false;
1518 }
1519
1520 TI_STMT_DEF_FIELDS(ret_type, offset);
1521 TI_DEFINE_ACCEPT_AND_CLONE
1522};
1523
1524/**
1525 * A block-local pointer, located at |offset| in the block-local storage.
1526 */
1527class BlockLocalPtrStmt : public Stmt {
1528 public:
1529 Stmt *offset;
1530
1531 BlockLocalPtrStmt(Stmt *offset, const DataType &ret_type) : offset(offset) {
1532 this->ret_type = ret_type;
1533 TI_STMT_REG_FIELDS;
1534 }
1535
1536 bool has_global_side_effect() const override {
1537 return false;
1538 }
1539
1540 TI_STMT_DEF_FIELDS(ret_type, offset);
1541 TI_DEFINE_ACCEPT_AND_CLONE
1542};
1543
1544/**
1545 * The statement corresponding to a clear-list task.
1546 */
1547class ClearListStmt : public Stmt {
1548 public:
1549 explicit ClearListStmt(SNode *snode);
1550
1551 SNode *snode;
1552
1553 TI_STMT_DEF_FIELDS(ret_type, snode);
1554 TI_DEFINE_ACCEPT_AND_CLONE
1555};
1556
1557// Checks if the task represented by |stmt| contains a single ClearListStmt.
1558bool is_clear_list_task(const OffloadedStmt *stmt);
1559
// A call to an internal (runtime-provided) function identified by name.
class InternalFuncStmt : public Stmt {
 public:
  std::string func_name;
  std::vector<Stmt *> args;
  // Presumably controls whether a RuntimeContext argument is passed at
  // codegen -- TODO confirm.
  bool with_runtime_context;

  // |ret_type| defaults to i32 when not specified.
  explicit InternalFuncStmt(const std::string &func_name,
                            const std::vector<Stmt *> &args,
                            Type *ret_type = nullptr,
                            bool with_runtime_context = true)
      : func_name(func_name),
        args(args),
        with_runtime_context(with_runtime_context) {
    if (ret_type == nullptr) {
      this->ret_type = PrimitiveType::i32;
    } else {
      this->ret_type = ret_type;
    }
    TI_STMT_REG_FIELDS;
  }

  TI_STMT_DEF_FIELDS(ret_type, func_name, args, with_runtime_context);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1584
1585class Texture;
1586
// A pointer to a texture, derived from a kernel argument load.
class TexturePtrStmt : public Stmt {
 public:
  Stmt *arg_load_stmt{nullptr};  // the argument-load producing the handle
  int dimensions{2};
  bool is_storage{false};  // storage (read-write) vs. sampled texture

  // Optional, for storage textures
  int num_channels{0};
  DataType channel_format{PrimitiveType::f32};
  int lod{0};

  explicit TexturePtrStmt(Stmt *stmt,
                          int dimensions,
                          bool is_storage,
                          int num_channels,
                          DataType channel_format,
                          int lod)
      : arg_load_stmt(stmt),
        dimensions(dimensions),
        is_storage(is_storage),
        num_channels(num_channels),
        channel_format(channel_format),
        lod(lod) {
    TI_STMT_REG_FIELDS;
  }

  // Sampled-texture form: storage-only fields keep their defaulted member
  // initializers above.
  explicit TexturePtrStmt(Stmt *stmt, int dimensions)
      : arg_load_stmt(stmt), dimensions(dimensions), is_storage(false) {
    TI_STMT_REG_FIELDS;
  }

  TI_STMT_DEF_FIELDS(arg_load_stmt,
                     dimensions,
                     is_storage,
                     num_channels,
                     channel_format,
                     lod);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1626
// A texture operation (sample/load/store/...) on |texture_ptr|.
class TextureOpStmt : public Stmt {
 public:
  TextureOpType op;
  Stmt *texture_ptr;
  std::vector<Stmt *> args;  // op-specific operands (coordinates, values, ...)

  explicit TextureOpStmt(TextureOpType op,
                         Stmt *texture_ptr,
                         const std::vector<Stmt *> &args)
      : op(op), texture_ptr(texture_ptr), args(args) {
    TI_STMT_REG_FIELDS;
  }

  // NOTE(review): this override is deliberately disabled, so the base-class
  // default of has_global_side_effect() applies even to non-store ops --
  // confirm whether that is still intended.
  /*
  bool has_global_side_effect() const override {
    return op == TextureOpType::kStore;
  }
  */

  // Stores must not be deduplicated; other texture ops may be.
  bool common_statement_eliminable() const override {
    return op != TextureOpType::kStore;
  }

  TI_STMT_DEF_FIELDS(op, texture_ptr, args);
  TI_DEFINE_ACCEPT_AND_CLONE
};
1653
1654/**
1655 * A local AD-stack.
1656 */
1657class AdStackAllocaStmt : public Stmt {
1658 public:
1659 DataType dt;
1660 std::size_t max_size{0}; // 0 = adaptive
1661
1662 AdStackAllocaStmt(const DataType &dt, std::size_t max_size)
1663 : dt(dt), max_size(max_size) {
1664 TI_STMT_REG_FIELDS;
1665 }
1666
1667 std::size_t element_size_in_bytes() const {
1668 return data_type_size(ret_type);
1669 }
1670
1671 std::size_t entry_size_in_bytes() const {
1672 return element_size_in_bytes() * 2;
1673 }
1674
1675 std::size_t size_in_bytes() const {
1676 return sizeof(int32) + entry_size_in_bytes() * max_size;
1677 }
1678
1679 bool has_global_side_effect() const override {
1680 return false;
1681 }
1682
1683 bool common_statement_eliminable() const override {
1684 return false;
1685 }
1686
1687 TI_STMT_DEF_FIELDS(ret_type, dt, max_size);
1688 TI_DEFINE_ACCEPT_AND_CLONE
1689};
1690
1691/**
1692 * Load the top primal value of an AD-stack.
1693 */
1694class AdStackLoadTopStmt : public Stmt, public ir_traits::Load {
1695 public:
1696 Stmt *stack;
1697
1698 explicit AdStackLoadTopStmt(Stmt *stack) {
1699 TI_ASSERT(stack->is<AdStackAllocaStmt>());
1700 this->stack = stack;
1701 TI_STMT_REG_FIELDS;
1702 }
1703
1704 bool has_global_side_effect() const override {
1705 return false;
1706 }
1707
1708 bool common_statement_eliminable() const override {
1709 return false;
1710 }
1711
1712 // IR Trait: Load
1713 stmt_refs get_load_pointers() const override {
1714 return stack;
1715 }
1716
1717 TI_STMT_DEF_FIELDS(ret_type, stack);
1718 TI_DEFINE_ACCEPT_AND_CLONE
1719};
1720
1721/**
1722 * Load the top adjoint value of an AD-stack.
1723 */
1724class AdStackLoadTopAdjStmt : public Stmt, public ir_traits::Load {
1725 public:
1726 Stmt *stack;
1727
1728 explicit AdStackLoadTopAdjStmt(Stmt *stack) {
1729 TI_ASSERT(stack->is<AdStackAllocaStmt>());
1730 this->stack = stack;
1731 TI_STMT_REG_FIELDS;
1732 }
1733
1734 bool has_global_side_effect() const override {
1735 return false;
1736 }
1737
1738 bool common_statement_eliminable() const override {
1739 return false;
1740 }
1741
1742 // IR Trait: Load
1743 stmt_refs get_load_pointers() const override {
1744 return stack;
1745 }
1746
1747 TI_STMT_DEF_FIELDS(ret_type, stack);
1748 TI_DEFINE_ACCEPT_AND_CLONE
1749};
1750
1751/**
1752 * Pop the top primal and adjoint values in the AD-stack.
1753 */
1754class AdStackPopStmt : public Stmt, public ir_traits::Load {
1755 public:
1756 Stmt *stack;
1757
1758 explicit AdStackPopStmt(Stmt *stack) {
1759 TI_ASSERT(stack->is<AdStackAllocaStmt>());
1760 this->stack = stack;
1761 TI_STMT_REG_FIELDS;
1762 }
1763
1764 // IR Trait: Load
1765 stmt_refs get_load_pointers() const override {
1766 // This is to make dead store elimination not eliminate consequent pops.
1767 return stack;
1768 }
1769
1770 // Mark has_global_side_effect == true to prevent being moved out of an if
1771 // clause in the simplify pass for now.
1772
1773 TI_STMT_DEF_FIELDS(ret_type, stack);
1774 TI_DEFINE_ACCEPT_AND_CLONE
1775};
1776
1777/**
1778 * Push a primal value to the AD-stack, and set the corresponding adjoint
1779 * value to 0.
1780 */
1781class AdStackPushStmt : public Stmt, public ir_traits::Load {
1782 public:
1783 Stmt *stack;
1784 Stmt *v;
1785
1786 AdStackPushStmt(Stmt *stack, Stmt *v) {
1787 TI_ASSERT(stack->is<AdStackAllocaStmt>());
1788 this->stack = stack;
1789 this->v = v;
1790 TI_STMT_REG_FIELDS;
1791 }
1792
1793 // IR Trait: Load
1794 stmt_refs get_load_pointers() const override {
1795 // This is to make dead store elimination not eliminate consequent pushes.
1796 return stack;
1797 }
1798
1799 // Mark has_global_side_effect == true to prevent being moved out of an if
1800 // clause in the simplify pass for now.
1801
1802 TI_STMT_DEF_FIELDS(ret_type, stack, v);
1803 TI_DEFINE_ACCEPT_AND_CLONE
1804};
1805
1806/**
1807 * Accumulate |v| to the top adjoint value of the AD-stack.
1808 * This statement loads and stores the adjoint data.
1809 */
1810class AdStackAccAdjointStmt : public Stmt, public ir_traits::Load {
1811 public:
1812 Stmt *stack;
1813 Stmt *v;
1814
1815 AdStackAccAdjointStmt(Stmt *stack, Stmt *v) {
1816 TI_ASSERT(stack->is<AdStackAllocaStmt>());
1817 this->stack = stack;
1818 this->v = v;
1819 TI_STMT_REG_FIELDS;
1820 }
1821
1822 // IR Trait: Load
1823 stmt_refs get_load_pointers() const override {
1824 return stack;
1825 }
1826
1827 // Mark has_global_side_effect == true to prevent being moved out of an if
1828 // clause in the simplify pass for now.
1829
1830 TI_STMT_DEF_FIELDS(ret_type, stack, v);
1831 TI_DEFINE_ACCEPT_AND_CLONE
1832};
1833
1834/**
1835 * A global store to one or more children of a bit struct.
1836 */
1837class BitStructStoreStmt : public Stmt {
1838 public:
1839 Stmt *ptr;
1840 std::vector<int> ch_ids;
1841 std::vector<Stmt *> values;
1842 bool is_atomic;
1843
1844 BitStructStoreStmt(Stmt *ptr,
1845 const std::vector<int> &ch_ids,
1846 const std::vector<Stmt *> &values)
1847 : ptr(ptr), ch_ids(ch_ids), values(values), is_atomic(true) {
1848 TI_ASSERT(ch_ids.size() == values.size());
1849 TI_STMT_REG_FIELDS;
1850 }
1851
1852 BitStructType *get_bit_struct() const;
1853
1854 bool common_statement_eliminable() const override {
1855 return false;
1856 }
1857
1858 TI_STMT_DEF_FIELDS(ret_type, ptr, ch_ids, values, is_atomic);
1859 TI_DEFINE_ACCEPT_AND_CLONE;
1860};
1861
1862// Mesh related.
1863
1864/**
1865 * The relation access, mesh_idx -> to_type[neighbor_idx]
 * If |neighbor_idx| has no value, it returns the number of neighbors (length
 * of the relation) of a mesh idx
1868 */
1869class MeshRelationAccessStmt : public Stmt {
1870 public:
1871 mesh::Mesh *mesh;
1872 Stmt *mesh_idx;
1873 mesh::MeshElementType to_type;
1874 Stmt *neighbor_idx;
1875
1876 MeshRelationAccessStmt(mesh::Mesh *mesh,
1877 Stmt *mesh_idx,
1878 mesh::MeshElementType to_type,
1879 Stmt *neighbor_idx)
1880 : mesh(mesh),
1881 mesh_idx(mesh_idx),
1882 to_type(to_type),
1883 neighbor_idx(neighbor_idx) {
1884 this->ret_type = PrimitiveType::u16;
1885 TI_STMT_REG_FIELDS;
1886 }
1887
1888 MeshRelationAccessStmt(mesh::Mesh *mesh,
1889 Stmt *mesh_idx,
1890 mesh::MeshElementType to_type)
1891 : mesh(mesh),
1892 mesh_idx(mesh_idx),
1893 to_type(to_type),
1894 neighbor_idx(nullptr) {
1895 this->ret_type = PrimitiveType::u16;
1896 TI_STMT_REG_FIELDS;
1897 }
1898
1899 bool is_size() const {
1900 return neighbor_idx == nullptr;
1901 }
1902
1903 bool has_global_side_effect() const override {
1904 return false;
1905 }
1906
1907 mesh::MeshElementType from_type() const {
1908 if (auto idx = mesh_idx->cast<LoopIndexStmt>()) {
1909 TI_ASSERT(idx->is_mesh_index());
1910 return idx->mesh_index_type();
1911 } else if (auto idx = mesh_idx->cast<MeshRelationAccessStmt>()) {
1912 TI_ASSERT(!idx->is_size());
1913 return idx->to_type;
1914 } else {
1915 TI_NOT_IMPLEMENTED;
1916 }
1917 }
1918
1919 TI_STMT_DEF_FIELDS(ret_type, mesh, mesh_idx, to_type, neighbor_idx);
1920 TI_DEFINE_ACCEPT_AND_CLONE
1921};
1922
1923/**
1924 * Convert a mesh index to another index space
1925 */
1926class MeshIndexConversionStmt : public Stmt {
1927 public:
1928 mesh::Mesh *mesh;
1929 mesh::MeshElementType idx_type;
1930 Stmt *idx;
1931
1932 mesh::ConvType conv_type;
1933
1934 MeshIndexConversionStmt(mesh::Mesh *mesh,
1935 mesh::MeshElementType idx_type,
1936 Stmt *idx,
1937 mesh::ConvType conv_type)
1938 : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) {
1939 this->ret_type = PrimitiveType::i32;
1940 TI_STMT_REG_FIELDS;
1941 }
1942
1943 bool has_global_side_effect() const override {
1944 return false;
1945 }
1946
1947 TI_STMT_DEF_FIELDS(ret_type, mesh, idx_type, idx, conv_type);
1948 TI_DEFINE_ACCEPT_AND_CLONE
1949};
1950
1951/**
1952 * The patch index of the |mesh_loop|.
1953 */
1954class MeshPatchIndexStmt : public Stmt {
1955 public:
1956 MeshPatchIndexStmt() {
1957 this->ret_type = PrimitiveType::i32;
1958 TI_STMT_REG_FIELDS;
1959 }
1960
1961 bool has_global_side_effect() const override {
1962 return false;
1963 }
1964
1965 TI_STMT_DEF_FIELDS(ret_type);
1966 TI_DEFINE_ACCEPT_AND_CLONE
1967};
1968
1969/**
1970 * Initialization of a local matrix
1971 */
1972class MatrixInitStmt : public Stmt {
1973 public:
1974 std::vector<Stmt *> values;
1975
1976 explicit MatrixInitStmt(const std::vector<Stmt *> &values) : values(values) {
1977 TI_STMT_REG_FIELDS;
1978 }
1979
1980 bool has_global_side_effect() const override {
1981 return false;
1982 }
1983
1984 TI_STMT_DEF_FIELDS(ret_type, values);
1985 TI_DEFINE_ACCEPT_AND_CLONE
1986};
1987
1988} // namespace taichi::lang
1989