1 | #pragma once |
2 | |
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "taichi/ir/snode_types.h"
#include "taichi/ir/stmt_op_types.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/expression.h"
#include "taichi/rhi/arch.h"
#include "taichi/program/function.h"
#include "taichi/ir/mesh.h"
13 | |
14 | namespace taichi::lang { |
15 | |
16 | class ASTBuilder; |
17 | |
18 | struct ForLoopConfig { |
19 | bool is_bit_vectorized{false}; |
20 | int num_cpu_threads{0}; |
21 | bool strictly_serialized{false}; |
22 | MemoryAccessOptions mem_access_opt; |
23 | int block_dim{0}; |
24 | bool uniform{false}; |
25 | }; |
26 | |
27 | // Frontend Statements |
28 | class FrontendExternalFuncStmt : public Stmt { |
29 | public: |
30 | void *so_func; |
31 | std::string asm_source; |
32 | std::string bc_filename; |
33 | std::string bc_funcname; |
34 | std::vector<Expr> args; |
35 | std::vector<Expr> outputs; |
36 | |
37 | FrontendExternalFuncStmt(void *so_func, |
38 | const std::string &asm_source, |
39 | const std::string &bc_filename, |
40 | const std::string &bc_funcname, |
41 | const std::vector<Expr> &args, |
42 | const std::vector<Expr> &outputs) |
43 | : so_func(so_func), |
44 | asm_source(asm_source), |
45 | bc_filename(bc_filename), |
46 | bc_funcname(bc_funcname), |
47 | args(args), |
48 | outputs(outputs) { |
49 | } |
50 | |
51 | TI_DEFINE_ACCEPT |
52 | }; |
53 | |
54 | class FrontendExprStmt : public Stmt { |
55 | public: |
56 | Expr val; |
57 | |
58 | explicit FrontendExprStmt(const Expr &val) : val(val) { |
59 | } |
60 | |
61 | TI_DEFINE_ACCEPT |
62 | }; |
63 | |
64 | class FrontendAllocaStmt : public Stmt { |
65 | public: |
66 | Identifier ident; |
67 | |
68 | FrontendAllocaStmt(const Identifier &lhs, DataType type) |
69 | : ident(lhs), is_shared(false) { |
70 | ret_type = type; |
71 | } |
72 | |
73 | FrontendAllocaStmt(const Identifier &lhs, |
74 | std::vector<int> shape, |
75 | DataType element, |
76 | bool is_shared = false) |
77 | : ident(lhs), is_shared(is_shared) { |
78 | ret_type = DataType(TypeFactory::create_tensor_type(shape, element)); |
79 | } |
80 | |
81 | bool is_shared; |
82 | |
83 | TI_DEFINE_ACCEPT |
84 | }; |
85 | |
86 | class FrontendSNodeOpStmt : public Stmt { |
87 | public: |
88 | SNodeOpType op_type; |
89 | SNode *snode; |
90 | ExprGroup indices; |
91 | Expr val; |
92 | |
93 | FrontendSNodeOpStmt(SNodeOpType op_type, |
94 | SNode *snode, |
95 | const ExprGroup &indices, |
96 | const Expr &val = Expr(nullptr)); |
97 | |
98 | TI_DEFINE_ACCEPT |
99 | }; |
100 | |
101 | class FrontendAssertStmt : public Stmt { |
102 | public: |
103 | std::string text; |
104 | Expr cond; |
105 | std::vector<Expr> args; |
106 | |
107 | FrontendAssertStmt(const Expr &cond, const std::string &text) |
108 | : text(text), cond(cond) { |
109 | } |
110 | |
111 | FrontendAssertStmt(const Expr &cond, |
112 | const std::string &text, |
113 | const std::vector<Expr> &args_) |
114 | : text(text), cond(cond) { |
115 | for (auto &a : args_) { |
116 | args.push_back(a); |
117 | } |
118 | } |
119 | |
120 | TI_DEFINE_ACCEPT |
121 | }; |
122 | |
123 | class FrontendAssignStmt : public Stmt { |
124 | public: |
125 | Expr lhs, rhs; |
126 | |
127 | FrontendAssignStmt(const Expr &lhs, const Expr &rhs); |
128 | |
129 | TI_DEFINE_ACCEPT |
130 | }; |
131 | |
132 | class FrontendIfStmt : public Stmt { |
133 | public: |
134 | Expr condition; |
135 | std::unique_ptr<Block> true_statements, false_statements; |
136 | |
137 | explicit FrontendIfStmt(const Expr &condition) : condition(condition) { |
138 | } |
139 | |
140 | bool is_container_statement() const override { |
141 | return true; |
142 | } |
143 | |
144 | TI_DEFINE_ACCEPT |
145 | }; |
146 | |
147 | class FrontendPrintStmt : public Stmt { |
148 | public: |
149 | using EntryType = std::variant<Expr, std::string>; |
150 | std::vector<EntryType> contents; |
151 | |
152 | explicit FrontendPrintStmt(const std::vector<EntryType> &contents_) { |
153 | for (const auto &c : contents_) { |
154 | if (std::holds_alternative<Expr>(c)) |
155 | contents.push_back(std::get<Expr>(c)); |
156 | else |
157 | contents.push_back(c); |
158 | } |
159 | } |
160 | |
161 | TI_DEFINE_ACCEPT |
162 | }; |
163 | |
164 | class FrontendForStmt : public Stmt { |
165 | public: |
166 | SNode *snode{nullptr}; |
167 | Expr external_tensor; |
168 | mesh::Mesh *mesh{nullptr}; |
169 | mesh::MeshElementType element_type; |
170 | Expr begin, end; |
171 | std::unique_ptr<Block> body; |
172 | std::vector<Identifier> loop_var_ids; |
173 | bool is_bit_vectorized; |
174 | int num_cpu_threads; |
175 | bool strictly_serialized; |
176 | MemoryAccessOptions mem_access_opt; |
177 | int block_dim; |
178 | |
179 | FrontendForStmt(const ExprGroup &loop_vars, |
180 | SNode *snode, |
181 | Arch arch, |
182 | const ForLoopConfig &config); |
183 | |
184 | FrontendForStmt(const ExprGroup &loop_vars, |
185 | const Expr &external_tensor, |
186 | Arch arch, |
187 | const ForLoopConfig &config); |
188 | |
189 | FrontendForStmt(const ExprGroup &loop_vars, |
190 | const mesh::MeshPtr &mesh, |
191 | const mesh::MeshElementType &element_type, |
192 | Arch arch, |
193 | const ForLoopConfig &config); |
194 | |
195 | FrontendForStmt(const Expr &loop_var, |
196 | const Expr &begin, |
197 | const Expr &end, |
198 | Arch arch, |
199 | const ForLoopConfig &config); |
200 | |
201 | bool is_container_statement() const override { |
202 | return true; |
203 | } |
204 | |
205 | TI_DEFINE_ACCEPT |
206 | |
207 | private: |
208 | void init_config(Arch arch, const ForLoopConfig &config); |
209 | |
210 | void init_loop_vars(const ExprGroup &loop_vars); |
211 | |
212 | void add_loop_var(const Expr &loop_var); |
213 | }; |
214 | |
215 | class FrontendFuncDefStmt : public Stmt { |
216 | public: |
217 | std::string funcid; |
218 | std::unique_ptr<Block> body; |
219 | |
220 | explicit FrontendFuncDefStmt(const std::string &funcid) : funcid(funcid) { |
221 | } |
222 | |
223 | bool is_container_statement() const override { |
224 | return true; |
225 | } |
226 | |
227 | TI_DEFINE_ACCEPT |
228 | }; |
229 | |
230 | class FrontendBreakStmt : public Stmt { |
231 | public: |
232 | FrontendBreakStmt() { |
233 | } |
234 | |
235 | bool is_container_statement() const override { |
236 | return false; |
237 | } |
238 | |
239 | TI_DEFINE_ACCEPT |
240 | }; |
241 | |
242 | class FrontendContinueStmt : public Stmt { |
243 | public: |
244 | FrontendContinueStmt() = default; |
245 | |
246 | bool is_container_statement() const override { |
247 | return false; |
248 | } |
249 | |
250 | TI_DEFINE_ACCEPT |
251 | }; |
252 | |
253 | class FrontendWhileStmt : public Stmt { |
254 | public: |
255 | Expr cond; |
256 | std::unique_ptr<Block> body; |
257 | |
258 | explicit FrontendWhileStmt(const Expr &cond) : cond(cond) { |
259 | } |
260 | |
261 | bool is_container_statement() const override { |
262 | return true; |
263 | } |
264 | |
265 | TI_DEFINE_ACCEPT |
266 | }; |
267 | |
268 | class FrontendReturnStmt : public Stmt { |
269 | public: |
270 | ExprGroup values; |
271 | |
272 | explicit FrontendReturnStmt(const ExprGroup &group); |
273 | |
274 | bool is_container_statement() const override { |
275 | return false; |
276 | } |
277 | |
278 | TI_DEFINE_ACCEPT |
279 | }; |
280 | |
281 | // Expressions |
282 | |
283 | class ArgLoadExpression : public Expression { |
284 | public: |
285 | int arg_id; |
286 | DataType dt; |
287 | bool is_ptr; |
288 | |
289 | ArgLoadExpression(int arg_id, DataType dt, bool is_ptr = false) |
290 | : arg_id(arg_id), dt(dt), is_ptr(is_ptr) { |
291 | } |
292 | |
293 | void type_check(const CompileConfig *config) override; |
294 | |
295 | void flatten(FlattenContext *ctx) override; |
296 | |
297 | bool is_lvalue() const override { |
298 | return is_ptr; |
299 | } |
300 | |
301 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
302 | }; |
303 | |
304 | class Texture; |
305 | |
306 | class TexturePtrExpression : public Expression { |
307 | public: |
308 | int arg_id; |
309 | int num_dims; |
310 | bool is_storage{false}; |
311 | |
312 | // Optional, for storage textures |
313 | int num_channels{0}; |
314 | DataType channel_format{PrimitiveType::f32}; |
315 | int lod{0}; |
316 | |
317 | explicit TexturePtrExpression(int arg_id, int num_dims) |
318 | : arg_id(arg_id), num_dims(num_dims) { |
319 | } |
320 | |
321 | TexturePtrExpression(int arg_id, |
322 | int num_dims, |
323 | int num_channels, |
324 | DataType channel_format, |
325 | int lod) |
326 | : arg_id(arg_id), |
327 | num_dims(num_dims), |
328 | is_storage(true), |
329 | num_channels(num_channels), |
330 | channel_format(channel_format), |
331 | lod(lod) { |
332 | } |
333 | |
334 | void type_check(const CompileConfig *config) override; |
335 | |
336 | void flatten(FlattenContext *ctx) override; |
337 | |
338 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
339 | }; |
340 | |
341 | class RandExpression : public Expression { |
342 | public: |
343 | DataType dt; |
344 | |
345 | explicit RandExpression(DataType dt) : dt(dt) { |
346 | } |
347 | |
348 | void type_check(const CompileConfig *config) override; |
349 | |
350 | void flatten(FlattenContext *ctx) override; |
351 | |
352 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
353 | }; |
354 | |
355 | class UnaryOpExpression : public Expression { |
356 | public: |
357 | UnaryOpType type; |
358 | Expr operand; |
359 | DataType cast_type; |
360 | |
361 | UnaryOpExpression(UnaryOpType type, const Expr &operand) |
362 | : type(type), operand(operand) { |
363 | cast_type = PrimitiveType::unknown; |
364 | } |
365 | |
366 | UnaryOpExpression(UnaryOpType type, const Expr &operand, DataType cast_type) |
367 | : type(type), operand(operand), cast_type(cast_type) { |
368 | } |
369 | |
370 | void type_check(const CompileConfig *config) override; |
371 | |
372 | bool is_cast() const; |
373 | |
374 | void flatten(FlattenContext *ctx) override; |
375 | |
376 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
377 | }; |
378 | |
379 | class BinaryOpExpression : public Expression { |
380 | public: |
381 | BinaryOpType type; |
382 | Expr lhs, rhs; |
383 | |
384 | BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs) |
385 | : type(type), lhs(lhs), rhs(rhs) { |
386 | } |
387 | |
388 | void type_check(const CompileConfig *config) override; |
389 | |
390 | void flatten(FlattenContext *ctx) override; |
391 | |
392 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
393 | }; |
394 | |
395 | class TernaryOpExpression : public Expression { |
396 | public: |
397 | TernaryOpType type; |
398 | Expr op1, op2, op3; |
399 | |
400 | TernaryOpExpression(TernaryOpType type, |
401 | const Expr &op1, |
402 | const Expr &op2, |
403 | const Expr &op3) |
404 | : type(type) { |
405 | this->op1.set(op1); |
406 | this->op2.set(op2); |
407 | this->op3.set(op3); |
408 | } |
409 | |
410 | void type_check(const CompileConfig *config) override; |
411 | |
412 | void flatten(FlattenContext *ctx) override; |
413 | |
414 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
415 | }; |
416 | |
417 | class InternalFuncCallExpression : public Expression { |
418 | public: |
419 | std::string func_name; |
420 | std::vector<Expr> args; |
421 | bool with_runtime_context; |
422 | |
423 | InternalFuncCallExpression(const std::string &func_name, |
424 | const std::vector<Expr> &args_, |
425 | bool with_runtime_context) |
426 | : func_name(func_name), with_runtime_context(with_runtime_context) { |
427 | for (auto &a : args_) { |
428 | args.push_back(a); |
429 | } |
430 | } |
431 | |
432 | void type_check(const CompileConfig *config) override; |
433 | |
434 | void flatten(FlattenContext *ctx) override; |
435 | |
436 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
437 | }; |
438 | |
439 | // TODO: Make this a non-expr |
440 | class ExternalTensorExpression : public Expression { |
441 | public: |
442 | DataType dt; |
443 | int dim; |
444 | int arg_id; |
445 | int element_dim; // 0: scalar; 1: vector (SOA); 2: matrix (SOA); -1: vector |
446 | // (AOS); -2: matrix (AOS) |
447 | bool is_grad; |
448 | |
449 | ExternalTensorExpression(const DataType &dt, |
450 | int dim, |
451 | int arg_id, |
452 | int element_dim, |
453 | bool is_grad = false) { |
454 | init(dt, dim, arg_id, element_dim, is_grad); |
455 | } |
456 | |
457 | ExternalTensorExpression(const DataType &dt, |
458 | int dim, |
459 | int arg_id, |
460 | int element_dim, |
461 | const std::vector<int> &element_shape, |
462 | bool is_grad = false) { |
463 | if (element_shape.size() == 0) { |
464 | init(dt, dim, arg_id, element_dim, is_grad); |
465 | } else { |
466 | TI_ASSERT(dt->is<PrimitiveType>()); |
467 | |
468 | auto tensor_type = |
469 | taichi::lang::TypeFactory::get_instance().create_tensor_type( |
470 | element_shape, dt); |
471 | init(tensor_type, dim, arg_id, element_dim, is_grad); |
472 | } |
473 | } |
474 | |
475 | explicit ExternalTensorExpression(Expr *expr, bool is_grad = true) { |
476 | auto ptr = expr->cast<ExternalTensorExpression>(); |
477 | init(ptr->dt, ptr->dim, ptr->arg_id, ptr->element_dim, is_grad); |
478 | } |
479 | |
480 | void flatten(FlattenContext *ctx) override; |
481 | |
482 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
483 | |
484 | const CompileConfig *get_compile_config() { |
485 | TI_ASSERT(config_ != nullptr); |
486 | return config_; |
487 | } |
488 | |
489 | void type_check(const CompileConfig *config) override { |
490 | ret_type = dt; |
491 | config_ = config; |
492 | } |
493 | |
494 | private: |
495 | const CompileConfig *config_ = nullptr; |
496 | |
497 | void init(const DataType &dt, |
498 | int dim, |
499 | int arg_id, |
500 | int element_dim, |
501 | bool is_grad) { |
502 | this->dt = dt; |
503 | this->dim = dim; |
504 | this->arg_id = arg_id; |
505 | this->element_dim = element_dim; |
506 | this->is_grad = is_grad; |
507 | } |
508 | }; |
509 | |
510 | // TODO: Make this a non-expr |
511 | class FieldExpression : public Expression { |
512 | public: |
513 | Identifier ident; |
514 | DataType dt; |
515 | std::string name; |
516 | SNode *snode{nullptr}; |
517 | SNodeGradType snode_grad_type{SNodeGradType::kPrimal}; |
518 | bool has_ambient{false}; |
519 | TypedConstant ambient_value; |
520 | Expr adjoint; |
521 | Expr dual; |
522 | Expr adjoint_checkbit; |
523 | |
524 | FieldExpression(DataType dt, const Identifier &ident) : ident(ident), dt(dt) { |
525 | } |
526 | |
527 | void type_check(const CompileConfig *config) override { |
528 | } |
529 | |
530 | void set_snode(SNode *snode) { |
531 | this->snode = snode; |
532 | } |
533 | |
534 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
535 | }; |
536 | |
537 | class MatrixFieldExpression : public Expression { |
538 | public: |
539 | std::vector<Expr> fields; |
540 | std::vector<int> element_shape; |
541 | bool dynamic_indexable{false}; |
542 | int dynamic_index_stride{0}; |
543 | |
544 | MatrixFieldExpression(const std::vector<Expr> &fields, |
545 | const std::vector<int> &element_shape) |
546 | : fields(fields), element_shape(element_shape) { |
547 | for (auto &field : fields) { |
548 | TI_ASSERT(field.is<FieldExpression>()); |
549 | } |
550 | TI_ASSERT(!fields.empty()); |
551 | auto compute_type = |
552 | fields[0].cast<FieldExpression>()->dt->get_compute_type(); |
553 | for (auto &field : fields) { |
554 | if (field.cast<FieldExpression>()->dt->get_compute_type() != |
555 | compute_type) { |
556 | throw TaichiRuntimeError( |
557 | "Member fields of a matrix field must have the same compute type" ); |
558 | } |
559 | } |
560 | } |
561 | |
562 | void type_check(const CompileConfig *config) override { |
563 | } |
564 | |
565 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
566 | }; |
567 | |
568 | /** |
569 | * Creating a local matrix; |
570 | * lowered from ti.Matrix |
571 | */ |
572 | class MatrixExpression : public Expression { |
573 | public: |
574 | std::vector<Expr> elements; |
575 | DataType dt; |
576 | |
577 | MatrixExpression(const std::vector<Expr> &elements, |
578 | std::vector<int> shape, |
579 | DataType element_type) |
580 | : elements(elements) { |
581 | this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type)); |
582 | } |
583 | |
584 | void type_check(const CompileConfig *config) override; |
585 | |
586 | void flatten(FlattenContext *ctx) override; |
587 | |
588 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
589 | }; |
590 | |
591 | class IndexExpression : public Expression { |
592 | public: |
593 | // `var` is one of FieldExpression, MatrixFieldExpression, |
594 | // ExternalTensorExpression, IdExpression |
595 | Expr var; |
596 | // In the cases of matrix slice and vector swizzle, there can be multiple |
597 | // indices, and the corresponding ret_shape should also be recorded. In normal |
598 | // index expressions ret_shape will be left empty. |
599 | std::vector<ExprGroup> indices_group; |
600 | std::vector<int> ret_shape; |
601 | |
602 | IndexExpression(const Expr &var, |
603 | const ExprGroup &indices, |
604 | std::string tb = "" ); |
605 | |
606 | IndexExpression(const Expr &var, |
607 | const std::vector<ExprGroup> &indices_group, |
608 | const std::vector<int> &ret_shape, |
609 | std::string tb = "" ); |
610 | |
611 | void type_check(const CompileConfig *config) override; |
612 | |
613 | void flatten(FlattenContext *ctx) override; |
614 | |
615 | bool is_lvalue() const override { |
616 | return true; |
617 | } |
618 | |
619 | // whether the LocalLoad/Store or GlobalLoad/Store is to be used on the |
620 | // compiled stmt |
621 | bool is_local() const; |
622 | bool is_global() const; |
623 | |
624 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
625 | |
626 | private: |
627 | bool is_field() const; |
628 | bool is_matrix_field() const; |
629 | bool is_ndarray() const; |
630 | bool is_tensor() const; |
631 | }; |
632 | |
633 | class RangeAssumptionExpression : public Expression { |
634 | public: |
635 | Expr input, base; |
636 | int low, high; |
637 | |
638 | RangeAssumptionExpression(const Expr &input, |
639 | const Expr &base, |
640 | int low, |
641 | int high) |
642 | : input(input), base(base), low(low), high(high) { |
643 | } |
644 | |
645 | void type_check(const CompileConfig *config) override; |
646 | |
647 | void flatten(FlattenContext *ctx) override; |
648 | |
649 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
650 | }; |
651 | |
652 | class LoopUniqueExpression : public Expression { |
653 | public: |
654 | Expr input; |
655 | std::vector<SNode *> covers; |
656 | |
657 | LoopUniqueExpression(const Expr &input, const std::vector<SNode *> &covers) |
658 | : input(input), covers(covers) { |
659 | } |
660 | |
661 | void type_check(const CompileConfig *config) override; |
662 | |
663 | void flatten(FlattenContext *ctx) override; |
664 | |
665 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
666 | }; |
667 | |
668 | class IdExpression : public Expression { |
669 | public: |
670 | Identifier id; |
671 | |
672 | explicit IdExpression(const Identifier &id) : id(id) { |
673 | } |
674 | |
675 | void type_check(const CompileConfig *config) override { |
676 | } |
677 | |
678 | void flatten(FlattenContext *ctx) override; |
679 | |
680 | Stmt *flatten_noload(FlattenContext *ctx) { |
681 | return ctx->current_block->lookup_var(id); |
682 | } |
683 | |
684 | bool is_lvalue() const override { |
685 | return true; |
686 | } |
687 | |
688 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
689 | }; |
690 | |
691 | // ti.atomic_*() is an expression with side effect. |
692 | class AtomicOpExpression : public Expression { |
693 | public: |
694 | AtomicOpType op_type; |
695 | Expr dest, val; |
696 | |
697 | AtomicOpExpression(AtomicOpType op_type, const Expr &dest, const Expr &val) |
698 | : op_type(op_type), dest(dest), val(val) { |
699 | } |
700 | |
701 | void type_check(const CompileConfig *config) override; |
702 | |
703 | void flatten(FlattenContext *ctx) override; |
704 | |
705 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
706 | }; |
707 | |
708 | class SNodeOpExpression : public Expression { |
709 | public: |
710 | SNode *snode; |
711 | SNodeOpType op_type; |
712 | ExprGroup indices; |
713 | std::vector<Expr> values; // Only for op_type==append |
714 | |
715 | SNodeOpExpression(SNode *snode, |
716 | SNodeOpType op_type, |
717 | const ExprGroup &indices); |
718 | |
719 | SNodeOpExpression(SNode *snode, |
720 | SNodeOpType op_type, |
721 | const ExprGroup &indices, |
722 | const std::vector<Expr> &values); |
723 | |
724 | void type_check(const CompileConfig *config) override; |
725 | |
726 | void flatten(FlattenContext *ctx) override; |
727 | |
728 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
729 | }; |
730 | |
731 | class TextureOpExpression : public Expression { |
732 | public: |
733 | TextureOpType op; |
734 | Expr texture_ptr; |
735 | ExprGroup args; |
736 | |
737 | explicit TextureOpExpression(TextureOpType op, |
738 | Expr texture_ptr, |
739 | const ExprGroup &args); |
740 | |
741 | void type_check(const CompileConfig *config) override; |
742 | |
743 | void flatten(FlattenContext *ctx) override; |
744 | |
745 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
746 | }; |
747 | |
748 | class ConstExpression : public Expression { |
749 | public: |
750 | TypedConstant val; |
751 | |
752 | template <typename T> |
753 | explicit ConstExpression(const T &x) : val(x) { |
754 | ret_type = val.dt; |
755 | } |
756 | template <typename T> |
757 | ConstExpression(const DataType &dt, const T &x) : val({dt, x}) { |
758 | ret_type = dt; |
759 | } |
760 | |
761 | void type_check(const CompileConfig *config) override; |
762 | |
763 | void flatten(FlattenContext *ctx) override; |
764 | |
765 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
766 | }; |
767 | |
768 | class ExternalTensorShapeAlongAxisExpression : public Expression { |
769 | public: |
770 | Expr ptr; |
771 | int axis; |
772 | |
773 | ExternalTensorShapeAlongAxisExpression(const Expr &ptr, int axis) |
774 | : ptr(ptr), axis(axis) { |
775 | } |
776 | |
777 | void type_check(const CompileConfig *config) override; |
778 | |
779 | void flatten(FlattenContext *ctx) override; |
780 | |
781 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
782 | }; |
783 | |
784 | class FrontendFuncCallStmt : public Stmt { |
785 | public: |
786 | std::optional<Identifier> ident; |
787 | Function *func; |
788 | ExprGroup args; |
789 | |
790 | explicit FrontendFuncCallStmt( |
791 | Function *func, |
792 | const ExprGroup &args, |
793 | const std::optional<Identifier> &id = std::nullopt) |
794 | : ident(id), func(func), args(args) { |
795 | TI_ASSERT(id.has_value() == !func->rets.empty()); |
796 | } |
797 | |
798 | bool is_container_statement() const override { |
799 | return false; |
800 | } |
801 | |
802 | TI_DEFINE_ACCEPT |
803 | }; |
804 | |
805 | class GetElementExpression : public Expression { |
806 | public: |
807 | Expr src; |
808 | std::vector<int> index; |
809 | |
810 | void type_check(const CompileConfig *config) override; |
811 | |
812 | GetElementExpression(const Expr &src, std::vector<int> index) |
813 | : src(src), index(index) { |
814 | } |
815 | |
816 | void flatten(FlattenContext *ctx) override; |
817 | |
818 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
819 | }; |
820 | |
821 | // Mesh related. |
822 | |
823 | class MeshPatchIndexExpression : public Expression { |
824 | public: |
825 | MeshPatchIndexExpression() { |
826 | } |
827 | |
828 | void type_check(const CompileConfig *config) override; |
829 | |
830 | void flatten(FlattenContext *ctx) override; |
831 | |
832 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
833 | }; |
834 | |
835 | class MeshRelationAccessExpression : public Expression { |
836 | public: |
837 | mesh::Mesh *mesh; |
838 | Expr mesh_idx; |
839 | mesh::MeshElementType to_type; |
840 | Expr neighbor_idx; |
841 | |
842 | void type_check(const CompileConfig *config) override; |
843 | |
844 | MeshRelationAccessExpression(mesh::Mesh *mesh, |
845 | const Expr mesh_idx, |
846 | mesh::MeshElementType to_type) |
847 | : mesh(mesh), mesh_idx(mesh_idx), to_type(to_type) { |
848 | } |
849 | |
850 | MeshRelationAccessExpression(mesh::Mesh *mesh, |
851 | const Expr mesh_idx, |
852 | mesh::MeshElementType to_type, |
853 | const Expr neighbor_idx) |
854 | : mesh(mesh), |
855 | mesh_idx(mesh_idx), |
856 | to_type(to_type), |
857 | neighbor_idx(neighbor_idx) { |
858 | } |
859 | |
860 | void flatten(FlattenContext *ctx) override; |
861 | |
862 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
863 | }; |
864 | |
865 | class MeshIndexConversionExpression : public Expression { |
866 | public: |
867 | mesh::Mesh *mesh; |
868 | mesh::MeshElementType idx_type; |
869 | Expr idx; |
870 | mesh::ConvType conv_type; |
871 | |
872 | void type_check(const CompileConfig *config) override; |
873 | |
874 | MeshIndexConversionExpression(mesh::Mesh *mesh, |
875 | mesh::MeshElementType idx_type, |
876 | const Expr idx, |
877 | mesh::ConvType conv_type); |
878 | |
879 | void flatten(FlattenContext *ctx) override; |
880 | |
881 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
882 | }; |
883 | |
884 | class ReferenceExpression : public Expression { |
885 | public: |
886 | Expr var; |
887 | void type_check(const CompileConfig *config) override; |
888 | |
889 | explicit ReferenceExpression(const Expr &expr) : var(expr) { |
890 | } |
891 | |
892 | void flatten(FlattenContext *ctx) override; |
893 | |
894 | TI_DEFINE_ACCEPT_FOR_EXPRESSION |
895 | }; |
896 | |
897 | class ASTBuilder { |
898 | private: |
899 | enum LoopState { None, Outermost, Inner }; |
900 | enum LoopType { NotLoop, For, While }; |
901 | |
902 | class ForLoopDecoratorRecorder { |
903 | public: |
904 | ForLoopConfig config; |
905 | |
906 | ForLoopDecoratorRecorder() { |
907 | reset(); |
908 | } |
909 | |
910 | void reset() { |
911 | config.is_bit_vectorized = false; |
912 | config.num_cpu_threads = 0; |
913 | config.uniform = false; |
914 | config.mem_access_opt.clear(); |
915 | config.block_dim = 0; |
916 | config.strictly_serialized = false; |
917 | } |
918 | }; |
919 | |
920 | std::vector<Block *> stack_; |
921 | std::vector<LoopState> loop_state_stack_; |
922 | Arch arch_; |
923 | ForLoopDecoratorRecorder for_loop_dec_; |
924 | int id_counter_{0}; |
925 | |
926 | public: |
927 | ASTBuilder(Block *initial, Arch arch) : arch_(arch) { |
928 | stack_.push_back(initial); |
929 | loop_state_stack_.push_back(None); |
930 | } |
931 | |
932 | void insert(std::unique_ptr<Stmt> &&stmt, int location = -1); |
933 | |
934 | Block *current_block(); |
935 | Stmt *get_last_stmt(); |
936 | void stop_gradient(SNode *); |
937 | void insert_assignment(Expr &lhs, |
938 | const Expr &rhs, |
939 | const std::string &tb = "" ); |
940 | Expr make_var(const Expr &x, std::string tb); |
941 | void insert_for(const Expr &s, |
942 | const Expr &e, |
943 | const std::function<void(Expr)> &func); |
944 | |
945 | Expr make_id_expr(const std::string &name); |
946 | Expr make_matrix_expr(const std::vector<int> &shape, |
947 | const DataType &dt, |
948 | const std::vector<Expr> &elements); |
949 | Expr insert_thread_idx_expr(); |
950 | Expr insert_patch_idx_expr(); |
951 | void create_kernel_exprgroup_return(const ExprGroup &group); |
952 | void create_print(std::vector<std::variant<Expr, std::string>> contents); |
953 | void begin_func(const std::string &funcid); |
954 | void end_func(const std::string &funcid); |
955 | void begin_frontend_if(const Expr &cond); |
956 | void begin_frontend_if_true(); |
957 | void begin_frontend_if_false(); |
958 | void insert_external_func_call(std::size_t func_addr, |
959 | std::string source, |
960 | std::string filename, |
961 | std::string funcname, |
962 | const ExprGroup &args, |
963 | const ExprGroup &outputs); |
964 | Expr expr_alloca(); |
965 | Expr expr_alloca_shared_array(const std::vector<int> &shape, |
966 | const DataType &element_type); |
967 | Expr expr_subscript(const Expr &expr, |
968 | const ExprGroup &indices, |
969 | std::string tb = "" ); |
970 | |
971 | Expr mesh_index_conversion(mesh::MeshPtr mesh_ptr, |
972 | mesh::MeshElementType idx_type, |
973 | const Expr &idx, |
974 | mesh::ConvType &conv_type); |
975 | |
976 | void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb); |
977 | std::optional<Expr> insert_func_call(Function *func, const ExprGroup &args); |
978 | void create_assert_stmt(const Expr &cond, |
979 | const std::string &msg, |
980 | const std::vector<Expr> &args); |
981 | void begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e); |
982 | void begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars, |
983 | SNode *snode); |
984 | void begin_frontend_struct_for_on_external_tensor( |
985 | const ExprGroup &loop_vars, |
986 | const Expr &external_tensor); |
987 | void begin_frontend_mesh_for(const Expr &i, |
988 | const mesh::MeshPtr &mesh_ptr, |
989 | const mesh::MeshElementType &element_type); |
990 | void begin_frontend_while(const Expr &cond); |
991 | void insert_break_stmt(); |
992 | void insert_continue_stmt(); |
993 | void insert_expr_stmt(const Expr &val); |
994 | void insert_snode_activate(SNode *snode, const ExprGroup &expr_group); |
995 | void insert_snode_deactivate(SNode *snode, const ExprGroup &expr_group); |
996 | Expr make_texture_op_expr(const TextureOpType &op, |
997 | const Expr &texture_ptr, |
998 | const ExprGroup &args); |
999 | /* |
1000 | * This function allocates the space for a new item (a struct or a scalar) |
1001 | * in the Dynamic SNode, and assigns values to the elements inside it. |
1002 | * |
1003 | * When appending a struct, the size of vals must be equal to |
1004 | * the number of elements in the struct. When appending a scalar, |
1005 | * the size of vals must be one. |
1006 | */ |
1007 | Expr snode_append(SNode *snode, |
1008 | const ExprGroup &indices, |
1009 | const std::vector<Expr> &vals); |
1010 | Expr snode_is_active(SNode *snode, const ExprGroup &indices); |
1011 | Expr snode_length(SNode *snode, const ExprGroup &indices); |
1012 | Expr snode_get_addr(SNode *snode, const ExprGroup &indices); |
1013 | |
1014 | std::vector<Expr> expand_exprs(const std::vector<Expr> &exprs); |
1015 | |
1016 | void create_scope(std::unique_ptr<Block> &list, LoopType tp = NotLoop); |
1017 | void pop_scope(); |
1018 | |
1019 | void bit_vectorize() { |
1020 | for_loop_dec_.config.is_bit_vectorized = true; |
1021 | } |
1022 | |
1023 | void parallelize(int v) { |
1024 | for_loop_dec_.config.num_cpu_threads = v; |
1025 | } |
1026 | |
1027 | void strictly_serialize() { |
1028 | for_loop_dec_.config.strictly_serialized = true; |
1029 | } |
1030 | |
1031 | void block_dim(int v) { |
1032 | if (arch_ == Arch::cuda || arch_ == Arch::vulkan || arch_ == Arch::amdgpu) { |
1033 | TI_ASSERT((v % 32 == 0) || bit::is_power_of_two(v)); |
1034 | } else { |
1035 | TI_ASSERT(bit::is_power_of_two(v)); |
1036 | } |
1037 | for_loop_dec_.config.block_dim = v; |
1038 | } |
1039 | |
1040 | void insert_snode_access_flag(SNodeAccessFlag v, const Expr &field) { |
1041 | for_loop_dec_.config.mem_access_opt.add_flag(field.snode(), v); |
1042 | } |
1043 | |
1044 | void reset_snode_access_flag() { |
1045 | for_loop_dec_.reset(); |
1046 | } |
1047 | |
1048 | Identifier get_next_id(const std::string &name = "" ) { |
1049 | return Identifier(id_counter_++, name); |
1050 | } |
1051 | }; |
1052 | |
1053 | class FrontendContext { |
1054 | private: |
1055 | std::unique_ptr<ASTBuilder> current_builder_; |
1056 | std::unique_ptr<Block> root_node_; |
1057 | |
1058 | public: |
1059 | explicit FrontendContext(Arch arch) { |
1060 | root_node_ = std::make_unique<Block>(); |
1061 | current_builder_ = std::make_unique<ASTBuilder>(root_node_.get(), arch); |
1062 | } |
1063 | |
1064 | ASTBuilder &builder() { |
1065 | return *current_builder_; |
1066 | } |
1067 | |
1068 | std::unique_ptr<Block> get_root() { |
1069 | return std::move(root_node_); |
1070 | } |
1071 | }; |
1072 | |
1073 | Stmt *flatten_lvalue(Expr expr, Expression::FlattenContext *ctx); |
1074 | |
1075 | Stmt *flatten_rvalue(Expr expr, Expression::FlattenContext *ctx); |
1076 | |
1077 | } // namespace taichi::lang |
1078 | |