1#pragma once
2
3#include <string>
4#include <vector>
5
6#include "taichi/ir/snode_types.h"
7#include "taichi/ir/stmt_op_types.h"
8#include "taichi/ir/ir.h"
9#include "taichi/ir/expression.h"
10#include "taichi/rhi/arch.h"
11#include "taichi/program/function.h"
12#include "taichi/ir/mesh.h"
13
14namespace taichi::lang {
15
16class ASTBuilder;
17
/**
 * Options that apply to a single frontend for-loop.
 *
 * ASTBuilder records these through its ForLoopDecoratorRecorder
 * (bit_vectorize(), parallelize(), strictly_serialize(), block_dim(),
 * insert_snode_access_flag()). FrontendForStmt's constructors receive them.
 * Every field defaults to the value ForLoopDecoratorRecorder::reset()
 * restores.
 */
struct ForLoopConfig {
  // Set to true by ASTBuilder::bit_vectorize().
  bool is_bit_vectorized{false};
  // Thread count recorded by ASTBuilder::parallelize(); 0 is the default
  // (no value recorded).
  int num_cpu_threads{0};
  // Set to true by ASTBuilder::strictly_serialize().
  bool strictly_serialized{false};
  // Per-SNode access flags added by ASTBuilder::insert_snode_access_flag().
  MemoryAccessOptions mem_access_opt;
  // Block dimension recorded by ASTBuilder::block_dim(); 0 is the default.
  int block_dim{0};
  // No setter for this flag is visible in this header; reset() sets it to false.
  bool uniform{false};
};
26
27// Frontend Statements
/**
 * Frontend statement that calls an external (non-Taichi) function.
 *
 * The callee can be described by a raw function pointer (`so_func`), by
 * assembly source (`asm_source`), or by a bitcode file plus function name
 * (`bc_filename` / `bc_funcname`). Presumably only one of these is used per
 * call; confirm against ASTBuilder::insert_external_func_call.
 * `args` holds the input expressions. `outputs` holds the expressions that
 * receive results (presumably; confirm in the lowering pass).
 */
class FrontendExternalFuncStmt : public Stmt {
 public:
  void *so_func;
  std::string asm_source;
  std::string bc_filename;
  std::string bc_funcname;
  std::vector<Expr> args;
  std::vector<Expr> outputs;

  FrontendExternalFuncStmt(void *so_func,
                           const std::string &asm_source,
                           const std::string &bc_filename,
                           const std::string &bc_funcname,
                           const std::vector<Expr> &args,
                           const std::vector<Expr> &outputs)
      : so_func(so_func),
        asm_source(asm_source),
        bc_filename(bc_filename),
        bc_funcname(bc_funcname),
        args(args),
        outputs(outputs) {
  }

  TI_DEFINE_ACCEPT
};
53
/**
 * Statement wrapping a single expression `val` used in statement position.
 * It is presumably created by ASTBuilder::insert_expr_stmt; confirm in the
 * definition.
 */
class FrontendExprStmt : public Stmt {
 public:
  Expr val;

  explicit FrontendExprStmt(const Expr &val) : val(val) {
  }

  TI_DEFINE_ACCEPT
};
63
/**
 * Declares a local variable identified by `ident`.
 *
 * The two-argument constructor allocates a value of type `type`.
 * The shape constructor allocates a tensor of `shape` with `element`
 * entries; the resulting TensorType becomes the statement's ret_type.
 * `is_shared` is false unless the shape constructor is given `true`.
 * Presumably that is done by ASTBuilder::expr_alloca_shared_array; confirm.
 */
class FrontendAllocaStmt : public Stmt {
 public:
  Identifier ident;

  FrontendAllocaStmt(const Identifier &lhs, DataType type)
      : ident(lhs), is_shared(false) {
    ret_type = type;
  }

  FrontendAllocaStmt(const Identifier &lhs,
                     std::vector<int> shape,
                     DataType element,
                     bool is_shared = false)
      : ident(lhs), is_shared(is_shared) {
    ret_type = DataType(TypeFactory::create_tensor_type(shape, element));
  }

  // Declared after the constructors but still initialized by both of them.
  bool is_shared;

  TI_DEFINE_ACCEPT
};
85
/**
 * Applies the SNode operation `op_type` to `snode` at `indices`.
 * `val` is an optional operand; it defaults to a null Expr.
 * The constructor is defined out of line.
 */
class FrontendSNodeOpStmt : public Stmt {
 public:
  SNodeOpType op_type;
  SNode *snode;
  ExprGroup indices;
  Expr val;

  FrontendSNodeOpStmt(SNodeOpType op_type,
                      SNode *snode,
                      const ExprGroup &indices,
                      const Expr &val = Expr(nullptr));

  TI_DEFINE_ACCEPT
};
100
101class FrontendAssertStmt : public Stmt {
102 public:
103 std::string text;
104 Expr cond;
105 std::vector<Expr> args;
106
107 FrontendAssertStmt(const Expr &cond, const std::string &text)
108 : text(text), cond(cond) {
109 }
110
111 FrontendAssertStmt(const Expr &cond,
112 const std::string &text,
113 const std::vector<Expr> &args_)
114 : text(text), cond(cond) {
115 for (auto &a : args_) {
116 args.push_back(a);
117 }
118 }
119
120 TI_DEFINE_ACCEPT
121};
122
/**
 * Assigns `rhs` to `lhs`. The constructor is defined out of line.
 */
class FrontendAssignStmt : public Stmt {
 public:
  Expr lhs, rhs;

  FrontendAssignStmt(const Expr &lhs, const Expr &rhs);

  TI_DEFINE_ACCEPT
};
131
/**
 * Conditional on `condition`.
 *
 * The branch bodies live in `true_statements` and `false_statements`. They
 * start out null; presumably ASTBuilder::begin_frontend_if_true/_false fill
 * them. This is a container statement.
 */
class FrontendIfStmt : public Stmt {
 public:
  Expr condition;
  std::unique_ptr<Block> true_statements, false_statements;

  explicit FrontendIfStmt(const Expr &condition) : condition(condition) {
  }

  bool is_container_statement() const override {
    return true;
  }

  TI_DEFINE_ACCEPT
};
146
147class FrontendPrintStmt : public Stmt {
148 public:
149 using EntryType = std::variant<Expr, std::string>;
150 std::vector<EntryType> contents;
151
152 explicit FrontendPrintStmt(const std::vector<EntryType> &contents_) {
153 for (const auto &c : contents_) {
154 if (std::holds_alternative<Expr>(c))
155 contents.push_back(std::get<Expr>(c));
156 else
157 contents.push_back(c);
158 }
159 }
160
161 TI_DEFINE_ACCEPT
162};
163
164class FrontendForStmt : public Stmt {
165 public:
166 SNode *snode{nullptr};
167 Expr external_tensor;
168 mesh::Mesh *mesh{nullptr};
169 mesh::MeshElementType element_type;
170 Expr begin, end;
171 std::unique_ptr<Block> body;
172 std::vector<Identifier> loop_var_ids;
173 bool is_bit_vectorized;
174 int num_cpu_threads;
175 bool strictly_serialized;
176 MemoryAccessOptions mem_access_opt;
177 int block_dim;
178
179 FrontendForStmt(const ExprGroup &loop_vars,
180 SNode *snode,
181 Arch arch,
182 const ForLoopConfig &config);
183
184 FrontendForStmt(const ExprGroup &loop_vars,
185 const Expr &external_tensor,
186 Arch arch,
187 const ForLoopConfig &config);
188
189 FrontendForStmt(const ExprGroup &loop_vars,
190 const mesh::MeshPtr &mesh,
191 const mesh::MeshElementType &element_type,
192 Arch arch,
193 const ForLoopConfig &config);
194
195 FrontendForStmt(const Expr &loop_var,
196 const Expr &begin,
197 const Expr &end,
198 Arch arch,
199 const ForLoopConfig &config);
200
201 bool is_container_statement() const override {
202 return true;
203 }
204
205 TI_DEFINE_ACCEPT
206
207 private:
208 void init_config(Arch arch, const ForLoopConfig &config);
209
210 void init_loop_vars(const ExprGroup &loop_vars);
211
212 void add_loop_var(const Expr &loop_var);
213};
214
/**
 * Definition of a function identified by `funcid`.
 * Its statements go into `body`, and it is a container statement.
 */
class FrontendFuncDefStmt : public Stmt {
 public:
  std::string funcid;
  std::unique_ptr<Block> body;

  explicit FrontendFuncDefStmt(const std::string &funcid) : funcid(funcid) {
  }

  bool is_container_statement() const override {
    return true;
  }

  TI_DEFINE_ACCEPT
};
229
230class FrontendBreakStmt : public Stmt {
231 public:
232 FrontendBreakStmt() {
233 }
234
235 bool is_container_statement() const override {
236 return false;
237 }
238
239 TI_DEFINE_ACCEPT
240};
241
/**
 * `continue` inside a frontend loop. It carries no operands and is not a
 * container statement.
 */
class FrontendContinueStmt : public Stmt {
 public:
  FrontendContinueStmt() = default;

  bool is_container_statement() const override {
    return false;
  }

  TI_DEFINE_ACCEPT
};
252
/**
 * While-loop on `cond`. Its body lives in `body`, and it is a container
 * statement.
 */
class FrontendWhileStmt : public Stmt {
 public:
  Expr cond;
  std::unique_ptr<Block> body;

  explicit FrontendWhileStmt(const Expr &cond) : cond(cond) {
  }

  bool is_container_statement() const override {
    return true;
  }

  TI_DEFINE_ACCEPT
};
267
/**
 * Return statement carrying the returned expressions in `values`.
 * The constructor is defined out of line.
 */
class FrontendReturnStmt : public Stmt {
 public:
  ExprGroup values;

  explicit FrontendReturnStmt(const ExprGroup &group);

  bool is_container_statement() const override {
    return false;
  }

  TI_DEFINE_ACCEPT
};
280
281// Expressions
282
/**
 * Loads argument number `arg_id` with data type `dt`.
 *
 * When `is_ptr` is true the expression is an lvalue: is_lvalue() simply
 * returns `is_ptr`.
 */
class ArgLoadExpression : public Expression {
 public:
  int arg_id;
  DataType dt;
  bool is_ptr;

  ArgLoadExpression(int arg_id, DataType dt, bool is_ptr = false)
      : arg_id(arg_id), dt(dt), is_ptr(is_ptr) {
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  bool is_lvalue() const override {
    return is_ptr;
  }

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
303
class Texture;

/**
 * Refers to the texture passed in argument slot `arg_id`, which has
 * `num_dims` dimensions.
 *
 * The two-argument constructor leaves the storage-only fields at their
 * defaults. The five-argument constructor marks the texture as a storage
 * texture (`is_storage = true`). It also records the channel count, the
 * channel format and `lod`.
 */
class TexturePtrExpression : public Expression {
 public:
  int arg_id;
  int num_dims;
  bool is_storage{false};

  // Optional, for storage textures
  int num_channels{0};
  DataType channel_format{PrimitiveType::f32};
  int lod{0};

  explicit TexturePtrExpression(int arg_id, int num_dims)
      : arg_id(arg_id), num_dims(num_dims) {
  }

  TexturePtrExpression(int arg_id,
                       int num_dims,
                       int num_channels,
                       DataType channel_format,
                       int lod)
      : arg_id(arg_id),
        num_dims(num_dims),
        is_storage(true),
        num_channels(num_channels),
        channel_format(channel_format),
        lod(lod) {
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
340
/**
 * Random-number expression producing a value of type `dt`.
 * The generation itself happens in flatten()/later lowering.
 */
class RandExpression : public Expression {
 public:
  DataType dt;

  explicit RandExpression(DataType dt) : dt(dt) {
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
354
355class UnaryOpExpression : public Expression {
356 public:
357 UnaryOpType type;
358 Expr operand;
359 DataType cast_type;
360
361 UnaryOpExpression(UnaryOpType type, const Expr &operand)
362 : type(type), operand(operand) {
363 cast_type = PrimitiveType::unknown;
364 }
365
366 UnaryOpExpression(UnaryOpType type, const Expr &operand, DataType cast_type)
367 : type(type), operand(operand), cast_type(cast_type) {
368 }
369
370 void type_check(const CompileConfig *config) override;
371
372 bool is_cast() const;
373
374 void flatten(FlattenContext *ctx) override;
375
376 TI_DEFINE_ACCEPT_FOR_EXPRESSION
377};
378
/**
 * Applies the binary operator `type` to `lhs` and `rhs`.
 */
class BinaryOpExpression : public Expression {
 public:
  BinaryOpType type;
  Expr lhs, rhs;

  BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
      : type(type), lhs(lhs), rhs(rhs) {
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
394
/**
 * Applies the ternary operator `type` to `op1`, `op2` and `op3`.
 */
class TernaryOpExpression : public Expression {
 public:
  TernaryOpType type;
  Expr op1, op2, op3;

  TernaryOpExpression(TernaryOpType type,
                      const Expr &op1,
                      const Expr &op2,
                      const Expr &op3)
      : type(type) {
    // NOTE(review): the operands are attached with Expr::set() instead of
    // copy-construction. These may not be equivalent: the copy constructor
    // may also copy other Expr state. Verify in expr.h before "simplifying"
    // this into a member initializer list.
    this->op1.set(op1);
    this->op2.set(op2);
    this->op3.set(op3);
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
416
417class InternalFuncCallExpression : public Expression {
418 public:
419 std::string func_name;
420 std::vector<Expr> args;
421 bool with_runtime_context;
422
423 InternalFuncCallExpression(const std::string &func_name,
424 const std::vector<Expr> &args_,
425 bool with_runtime_context)
426 : func_name(func_name), with_runtime_context(with_runtime_context) {
427 for (auto &a : args_) {
428 args.push_back(a);
429 }
430 }
431
432 void type_check(const CompileConfig *config) override;
433
434 void flatten(FlattenContext *ctx) override;
435
436 TI_DEFINE_ACCEPT_FOR_EXPRESSION
437};
438
439// TODO: Make this a non-expr
440class ExternalTensorExpression : public Expression {
441 public:
442 DataType dt;
443 int dim;
444 int arg_id;
445 int element_dim; // 0: scalar; 1: vector (SOA); 2: matrix (SOA); -1: vector
446 // (AOS); -2: matrix (AOS)
447 bool is_grad;
448
449 ExternalTensorExpression(const DataType &dt,
450 int dim,
451 int arg_id,
452 int element_dim,
453 bool is_grad = false) {
454 init(dt, dim, arg_id, element_dim, is_grad);
455 }
456
457 ExternalTensorExpression(const DataType &dt,
458 int dim,
459 int arg_id,
460 int element_dim,
461 const std::vector<int> &element_shape,
462 bool is_grad = false) {
463 if (element_shape.size() == 0) {
464 init(dt, dim, arg_id, element_dim, is_grad);
465 } else {
466 TI_ASSERT(dt->is<PrimitiveType>());
467
468 auto tensor_type =
469 taichi::lang::TypeFactory::get_instance().create_tensor_type(
470 element_shape, dt);
471 init(tensor_type, dim, arg_id, element_dim, is_grad);
472 }
473 }
474
475 explicit ExternalTensorExpression(Expr *expr, bool is_grad = true) {
476 auto ptr = expr->cast<ExternalTensorExpression>();
477 init(ptr->dt, ptr->dim, ptr->arg_id, ptr->element_dim, is_grad);
478 }
479
480 void flatten(FlattenContext *ctx) override;
481
482 TI_DEFINE_ACCEPT_FOR_EXPRESSION
483
484 const CompileConfig *get_compile_config() {
485 TI_ASSERT(config_ != nullptr);
486 return config_;
487 }
488
489 void type_check(const CompileConfig *config) override {
490 ret_type = dt;
491 config_ = config;
492 }
493
494 private:
495 const CompileConfig *config_ = nullptr;
496
497 void init(const DataType &dt,
498 int dim,
499 int arg_id,
500 int element_dim,
501 bool is_grad) {
502 this->dt = dt;
503 this->dim = dim;
504 this->arg_id = arg_id;
505 this->element_dim = element_dim;
506 this->is_grad = is_grad;
507 }
508};
509
510// TODO: Make this a non-expr
/**
 * Describes a field identified by `ident` with element type `dt`.
 *
 * `snode` starts as null and is attached later through set_snode().
 * `adjoint`, `dual` and `adjoint_checkbit` hold related field expressions;
 * presumably these support autodiff (confirm).
 * `has_ambient` and `ambient_value` describe an optional ambient value.
 * type_check() does nothing for this node.
 */
class FieldExpression : public Expression {
 public:
  Identifier ident;
  DataType dt;
  std::string name;
  SNode *snode{nullptr};
  SNodeGradType snode_grad_type{SNodeGradType::kPrimal};
  bool has_ambient{false};
  TypedConstant ambient_value;
  Expr adjoint;
  Expr dual;
  Expr adjoint_checkbit;

  FieldExpression(DataType dt, const Identifier &ident) : ident(ident), dt(dt) {
  }

  void type_check(const CompileConfig *config) override {
  }

  void set_snode(SNode *snode) {
    this->snode = snode;
  }

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
536
537class MatrixFieldExpression : public Expression {
538 public:
539 std::vector<Expr> fields;
540 std::vector<int> element_shape;
541 bool dynamic_indexable{false};
542 int dynamic_index_stride{0};
543
544 MatrixFieldExpression(const std::vector<Expr> &fields,
545 const std::vector<int> &element_shape)
546 : fields(fields), element_shape(element_shape) {
547 for (auto &field : fields) {
548 TI_ASSERT(field.is<FieldExpression>());
549 }
550 TI_ASSERT(!fields.empty());
551 auto compute_type =
552 fields[0].cast<FieldExpression>()->dt->get_compute_type();
553 for (auto &field : fields) {
554 if (field.cast<FieldExpression>()->dt->get_compute_type() !=
555 compute_type) {
556 throw TaichiRuntimeError(
557 "Member fields of a matrix field must have the same compute type");
558 }
559 }
560 }
561
562 void type_check(const CompileConfig *config) override {
563 }
564
565 TI_DEFINE_ACCEPT_FOR_EXPRESSION
566};
567
568/**
569 * Creating a local matrix;
570 * lowered from ti.Matrix
571 */
572class MatrixExpression : public Expression {
573 public:
574 std::vector<Expr> elements;
575 DataType dt;
576
577 MatrixExpression(const std::vector<Expr> &elements,
578 std::vector<int> shape,
579 DataType element_type)
580 : elements(elements) {
581 this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type));
582 }
583
584 void type_check(const CompileConfig *config) override;
585
586 void flatten(FlattenContext *ctx) override;
587
588 TI_DEFINE_ACCEPT_FOR_EXPRESSION
589};
590
/**
 * Subscript of `var` by one or more index groups.
 *
 * The constructors and most member functions are defined out of line. `tb`
 * is a string parameter defaulting to "" (presumably a traceback used in
 * error messages; confirm in the definition).
 * is_lvalue() always returns true.
 */
class IndexExpression : public Expression {
 public:
  // `var` is one of FieldExpression, MatrixFieldExpression,
  // ExternalTensorExpression, IdExpression
  Expr var;
  // In the cases of matrix slice and vector swizzle, there can be multiple
  // indices, and the corresponding ret_shape should also be recorded. In normal
  // index expressions ret_shape will be left empty.
  std::vector<ExprGroup> indices_group;
  std::vector<int> ret_shape;

  IndexExpression(const Expr &var,
                  const ExprGroup &indices,
                  std::string tb = "");

  IndexExpression(const Expr &var,
                  const std::vector<ExprGroup> &indices_group,
                  const std::vector<int> &ret_shape,
                  std::string tb = "");

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  bool is_lvalue() const override {
    return true;
  }

  // whether the LocalLoad/Store or GlobalLoad/Store is to be used on the
  // compiled stmt
  bool is_local() const;
  bool is_global() const;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION

 private:
  // Classify the kind of `var` being indexed.
  bool is_field() const;
  bool is_matrix_field() const;
  bool is_ndarray() const;
  bool is_tensor() const;
};
632
/**
 * Wraps `input` with an assumption relating it to `base` through the integer
 * bounds `low` and `high`. The exact form is presumably
 * input - base in [low, high); confirm in flatten()/the lowering pass.
 */
class RangeAssumptionExpression : public Expression {
 public:
  Expr input, base;
  int low, high;

  RangeAssumptionExpression(const Expr &input,
                            const Expr &base,
                            int low,
                            int high)
      : input(input), base(base), low(low), high(high) {
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
651
/**
 * Wraps `input` together with the list of SNodes in `covers`.
 * As the name suggests, this presumably marks `input` as unique across loop
 * iterations for those SNodes; confirm in the lowering pass.
 */
class LoopUniqueExpression : public Expression {
 public:
  Expr input;
  std::vector<SNode *> covers;

  LoopUniqueExpression(const Expr &input, const std::vector<SNode *> &covers)
      : input(input), covers(covers) {
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
667
/**
 * Reference to a local variable identified by `id`. It is always an lvalue,
 * and type_check() does nothing for this node.
 */
class IdExpression : public Expression {
 public:
  Identifier id;

  explicit IdExpression(const Identifier &id) : id(id) {
  }

  void type_check(const CompileConfig *config) override {
  }

  void flatten(FlattenContext *ctx) override;

  // Returns the statement that the current block resolves `id` to, without
  // emitting any new statement into `ctx`.
  Stmt *flatten_noload(FlattenContext *ctx) {
    return ctx->current_block->lookup_var(id);
  }

  bool is_lvalue() const override {
    return true;
  }

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
690
// ti.atomic_*() is an expression with side effect.
// It applies the atomic operation `op_type` to `dest` with operand `val`.
class AtomicOpExpression : public Expression {
 public:
  AtomicOpType op_type;
  Expr dest, val;

  AtomicOpExpression(AtomicOpType op_type, const Expr &dest, const Expr &val)
      : op_type(op_type), dest(dest), val(val) {
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
707
/**
 * SNode operation `op_type` on `snode` at `indices`, used as an expression.
 * The constructors are defined out of line.
 */
class SNodeOpExpression : public Expression {
 public:
  SNode *snode;
  SNodeOpType op_type;
  ExprGroup indices;
  std::vector<Expr> values;  // Only for op_type==append

  SNodeOpExpression(SNode *snode,
                    SNodeOpType op_type,
                    const ExprGroup &indices);

  SNodeOpExpression(SNode *snode,
                    SNodeOpType op_type,
                    const ExprGroup &indices,
                    const std::vector<Expr> &values);

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
730
/**
 * Texture operation `op` on the texture referenced by `texture_ptr`, with
 * operands `args`. The constructor is defined out of line.
 */
class TextureOpExpression : public Expression {
 public:
  TextureOpType op;
  Expr texture_ptr;
  ExprGroup args;

  explicit TextureOpExpression(TextureOpType op,
                               Expr texture_ptr,
                               const ExprGroup &args);

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
747
/**
 * Constant value expression.
 *
 * ret_type is fixed at construction. It is either the type deduced by
 * TypedConstant from `x`, or the explicitly supplied `dt`.
 */
class ConstExpression : public Expression {
 public:
  TypedConstant val;

  template <typename T>
  explicit ConstExpression(const T &x) : val(x) {
    ret_type = val.dt;
  }
  template <typename T>
  ConstExpression(const DataType &dt, const T &x) : val({dt, x}) {
    ret_type = dt;
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
767
/**
 * As the name indicates, this is the extent of the external tensor `ptr`
 * along dimension `axis`.
 */
class ExternalTensorShapeAlongAxisExpression : public Expression {
 public:
  Expr ptr;
  int axis;

  ExternalTensorShapeAlongAxisExpression(const Expr &ptr, int axis)
      : ptr(ptr), axis(axis) {
  }

  void type_check(const CompileConfig *config) override;

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
783
/**
 * Calls the Taichi function `func` with `args`.
 *
 * `ident` names the variable that receives the result. The constructor
 * asserts that `ident` is present exactly when `func->rets` is non-empty.
 */
class FrontendFuncCallStmt : public Stmt {
 public:
  std::optional<Identifier> ident;
  Function *func;
  ExprGroup args;

  explicit FrontendFuncCallStmt(
      Function *func,
      const ExprGroup &args,
      const std::optional<Identifier> &id = std::nullopt)
      : ident(id), func(func), args(args) {
    TI_ASSERT(id.has_value() == !func->rets.empty());
  }

  bool is_container_statement() const override {
    return false;
  }

  TI_DEFINE_ACCEPT
};
804
/**
 * Selects the element at the constant multi-dimensional `index` from `src`.
 * `src` is presumably tensor-valued; confirm in type_check().
 */
class GetElementExpression : public Expression {
 public:
  Expr src;
  std::vector<int> index;

  void type_check(const CompileConfig *config) override;

  GetElementExpression(const Expr &src, std::vector<int> index)
      : src(src), index(index) {
  }

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
820
821// Mesh related.
822
823class MeshPatchIndexExpression : public Expression {
824 public:
825 MeshPatchIndexExpression() {
826 }
827
828 void type_check(const CompileConfig *config) override;
829
830 void flatten(FlattenContext *ctx) override;
831
832 TI_DEFINE_ACCEPT_FOR_EXPRESSION
833};
834
835class MeshRelationAccessExpression : public Expression {
836 public:
837 mesh::Mesh *mesh;
838 Expr mesh_idx;
839 mesh::MeshElementType to_type;
840 Expr neighbor_idx;
841
842 void type_check(const CompileConfig *config) override;
843
844 MeshRelationAccessExpression(mesh::Mesh *mesh,
845 const Expr mesh_idx,
846 mesh::MeshElementType to_type)
847 : mesh(mesh), mesh_idx(mesh_idx), to_type(to_type) {
848 }
849
850 MeshRelationAccessExpression(mesh::Mesh *mesh,
851 const Expr mesh_idx,
852 mesh::MeshElementType to_type,
853 const Expr neighbor_idx)
854 : mesh(mesh),
855 mesh_idx(mesh_idx),
856 to_type(to_type),
857 neighbor_idx(neighbor_idx) {
858 }
859
860 void flatten(FlattenContext *ctx) override;
861
862 TI_DEFINE_ACCEPT_FOR_EXPRESSION
863};
864
/**
 * Converts the index `idx` of an element of type `idx_type` on `mesh`
 * according to `conv_type`. The constructor is defined out of line.
 */
class MeshIndexConversionExpression : public Expression {
 public:
  mesh::Mesh *mesh;
  mesh::MeshElementType idx_type;
  Expr idx;
  mesh::ConvType conv_type;

  void type_check(const CompileConfig *config) override;

  // NOTE(review): `idx` is taken by value (const Expr), which costs an extra
  // copy. Switching to `const Expr &` must be done together with the
  // out-of-line definition, otherwise the declaration and definition no
  // longer match.
  MeshIndexConversionExpression(mesh::Mesh *mesh,
                                mesh::MeshElementType idx_type,
                                const Expr idx,
                                mesh::ConvType conv_type);

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
883
/**
 * Wraps `var` so that it is used by reference. Presumably it yields its
 * address rather than its value; confirm in flatten().
 */
class ReferenceExpression : public Expression {
 public:
  Expr var;
  void type_check(const CompileConfig *config) override;

  explicit ReferenceExpression(const Expr &expr) : var(expr) {
  }

  void flatten(FlattenContext *ctx) override;

  TI_DEFINE_ACCEPT_FOR_EXPRESSION
};
896
897class ASTBuilder {
898 private:
899 enum LoopState { None, Outermost, Inner };
900 enum LoopType { NotLoop, For, While };
901
902 class ForLoopDecoratorRecorder {
903 public:
904 ForLoopConfig config;
905
906 ForLoopDecoratorRecorder() {
907 reset();
908 }
909
910 void reset() {
911 config.is_bit_vectorized = false;
912 config.num_cpu_threads = 0;
913 config.uniform = false;
914 config.mem_access_opt.clear();
915 config.block_dim = 0;
916 config.strictly_serialized = false;
917 }
918 };
919
920 std::vector<Block *> stack_;
921 std::vector<LoopState> loop_state_stack_;
922 Arch arch_;
923 ForLoopDecoratorRecorder for_loop_dec_;
924 int id_counter_{0};
925
926 public:
927 ASTBuilder(Block *initial, Arch arch) : arch_(arch) {
928 stack_.push_back(initial);
929 loop_state_stack_.push_back(None);
930 }
931
932 void insert(std::unique_ptr<Stmt> &&stmt, int location = -1);
933
934 Block *current_block();
935 Stmt *get_last_stmt();
936 void stop_gradient(SNode *);
937 void insert_assignment(Expr &lhs,
938 const Expr &rhs,
939 const std::string &tb = "");
940 Expr make_var(const Expr &x, std::string tb);
941 void insert_for(const Expr &s,
942 const Expr &e,
943 const std::function<void(Expr)> &func);
944
945 Expr make_id_expr(const std::string &name);
946 Expr make_matrix_expr(const std::vector<int> &shape,
947 const DataType &dt,
948 const std::vector<Expr> &elements);
949 Expr insert_thread_idx_expr();
950 Expr insert_patch_idx_expr();
951 void create_kernel_exprgroup_return(const ExprGroup &group);
952 void create_print(std::vector<std::variant<Expr, std::string>> contents);
953 void begin_func(const std::string &funcid);
954 void end_func(const std::string &funcid);
955 void begin_frontend_if(const Expr &cond);
956 void begin_frontend_if_true();
957 void begin_frontend_if_false();
958 void insert_external_func_call(std::size_t func_addr,
959 std::string source,
960 std::string filename,
961 std::string funcname,
962 const ExprGroup &args,
963 const ExprGroup &outputs);
964 Expr expr_alloca();
965 Expr expr_alloca_shared_array(const std::vector<int> &shape,
966 const DataType &element_type);
967 Expr expr_subscript(const Expr &expr,
968 const ExprGroup &indices,
969 std::string tb = "");
970
971 Expr mesh_index_conversion(mesh::MeshPtr mesh_ptr,
972 mesh::MeshElementType idx_type,
973 const Expr &idx,
974 mesh::ConvType &conv_type);
975
976 void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb);
977 std::optional<Expr> insert_func_call(Function *func, const ExprGroup &args);
978 void create_assert_stmt(const Expr &cond,
979 const std::string &msg,
980 const std::vector<Expr> &args);
981 void begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e);
982 void begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars,
983 SNode *snode);
984 void begin_frontend_struct_for_on_external_tensor(
985 const ExprGroup &loop_vars,
986 const Expr &external_tensor);
987 void begin_frontend_mesh_for(const Expr &i,
988 const mesh::MeshPtr &mesh_ptr,
989 const mesh::MeshElementType &element_type);
990 void begin_frontend_while(const Expr &cond);
991 void insert_break_stmt();
992 void insert_continue_stmt();
993 void insert_expr_stmt(const Expr &val);
994 void insert_snode_activate(SNode *snode, const ExprGroup &expr_group);
995 void insert_snode_deactivate(SNode *snode, const ExprGroup &expr_group);
996 Expr make_texture_op_expr(const TextureOpType &op,
997 const Expr &texture_ptr,
998 const ExprGroup &args);
999 /*
1000 * This function allocates the space for a new item (a struct or a scalar)
1001 * in the Dynamic SNode, and assigns values to the elements inside it.
1002 *
1003 * When appending a struct, the size of vals must be equal to
1004 * the number of elements in the struct. When appending a scalar,
1005 * the size of vals must be one.
1006 */
1007 Expr snode_append(SNode *snode,
1008 const ExprGroup &indices,
1009 const std::vector<Expr> &vals);
1010 Expr snode_is_active(SNode *snode, const ExprGroup &indices);
1011 Expr snode_length(SNode *snode, const ExprGroup &indices);
1012 Expr snode_get_addr(SNode *snode, const ExprGroup &indices);
1013
1014 std::vector<Expr> expand_exprs(const std::vector<Expr> &exprs);
1015
1016 void create_scope(std::unique_ptr<Block> &list, LoopType tp = NotLoop);
1017 void pop_scope();
1018
1019 void bit_vectorize() {
1020 for_loop_dec_.config.is_bit_vectorized = true;
1021 }
1022
1023 void parallelize(int v) {
1024 for_loop_dec_.config.num_cpu_threads = v;
1025 }
1026
1027 void strictly_serialize() {
1028 for_loop_dec_.config.strictly_serialized = true;
1029 }
1030
1031 void block_dim(int v) {
1032 if (arch_ == Arch::cuda || arch_ == Arch::vulkan || arch_ == Arch::amdgpu) {
1033 TI_ASSERT((v % 32 == 0) || bit::is_power_of_two(v));
1034 } else {
1035 TI_ASSERT(bit::is_power_of_two(v));
1036 }
1037 for_loop_dec_.config.block_dim = v;
1038 }
1039
1040 void insert_snode_access_flag(SNodeAccessFlag v, const Expr &field) {
1041 for_loop_dec_.config.mem_access_opt.add_flag(field.snode(), v);
1042 }
1043
1044 void reset_snode_access_flag() {
1045 for_loop_dec_.reset();
1046 }
1047
1048 Identifier get_next_id(const std::string &name = "") {
1049 return Identifier(id_counter_++, name);
1050 }
1051};
1052
1053class FrontendContext {
1054 private:
1055 std::unique_ptr<ASTBuilder> current_builder_;
1056 std::unique_ptr<Block> root_node_;
1057
1058 public:
1059 explicit FrontendContext(Arch arch) {
1060 root_node_ = std::make_unique<Block>();
1061 current_builder_ = std::make_unique<ASTBuilder>(root_node_.get(), arch);
1062 }
1063
1064 ASTBuilder &builder() {
1065 return *current_builder_;
1066 }
1067
1068 std::unique_ptr<Block> get_root() {
1069 return std::move(root_node_);
1070 }
1071};
1072
// Flattens `expr` into statements within `ctx` and returns the resulting
// statement in lvalue form. The exact contract is in the out-of-line
// definition; TODO confirm.
Stmt *flatten_lvalue(Expr expr, Expression::FlattenContext *ctx);

// Flattens `expr` into statements within `ctx` and returns the statement
// producing its value (rvalue form). It presumably loads through lvalues;
// confirm in the definition.
Stmt *flatten_rvalue(Expr expr, Expression::FlattenContext *ctx);
1076
1077} // namespace taichi::lang
1078