1#include "taichi/ir/frontend_ir.h"
2
3#include "taichi/ir/expression_printer.h"
4#include "taichi/ir/statements.h"
5#include "taichi/program/program.h"
6#include "taichi/common/exceptions.h"
7
8#include <numeric>
9
10namespace taichi::lang {
11
12#define TI_ASSERT_TYPE_CHECKED(x) \
13 TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, \
14 "[{}] was not type-checked", \
15 ExpressionHumanFriendlyPrinter::expr_to_string(x))
16
17static bool is_primitive_or_tensor_type(DataType &type) {
18 return type->is<PrimitiveType>() || type->is<TensorType>();
19}
20
21FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type,
22 SNode *snode,
23 const ExprGroup &indices,
24 const Expr &val)
25 : op_type(op_type), snode(snode), indices(indices), val(val) {
26 if (val.expr != nullptr) {
27 TI_ASSERT(op_type == SNodeOpType::append);
28 } else {
29 TI_ASSERT(op_type != SNodeOpType::append);
30 }
31}
32
33FrontendReturnStmt::FrontendReturnStmt(const ExprGroup &group) : values(group) {
34}
35
36FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs)
37 : lhs(lhs), rhs(rhs) {
38 TI_ASSERT(lhs->is_lvalue());
39 if (lhs.is<IdExpression>() && lhs->ret_type == PrimitiveType::unknown) {
40 lhs.expr->ret_type = rhs->ret_type;
41 }
42}
43
44FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars,
45 SNode *snode,
46 Arch arch,
47 const ForLoopConfig &config)
48 : snode(snode) {
49 init_config(arch, config);
50 init_loop_vars(loop_vars);
51}
52
53FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars,
54 const Expr &external_tensor,
55 Arch arch,
56 const ForLoopConfig &config)
57 : external_tensor(external_tensor) {
58 init_config(arch, config);
59 init_loop_vars(loop_vars);
60}
61
62FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars,
63 const mesh::MeshPtr &mesh,
64 const mesh::MeshElementType &element_type,
65 Arch arch,
66 const ForLoopConfig &config)
67 : mesh(mesh.ptr.get()), element_type(element_type) {
68 init_config(arch, config);
69 init_loop_vars(loop_vars);
70}
71
72FrontendForStmt::FrontendForStmt(const Expr &loop_var,
73 const Expr &begin,
74 const Expr &end,
75 Arch arch,
76 const ForLoopConfig &config)
77 : begin(begin), end(end) {
78 init_config(arch, config);
79 add_loop_var(loop_var);
80}
81
82void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) {
83 is_bit_vectorized = config.is_bit_vectorized;
84 strictly_serialized = config.strictly_serialized;
85 mem_access_opt = config.mem_access_opt;
86 block_dim = config.block_dim;
87 if (arch == Arch::cuda || arch == Arch::amdgpu) {
88 num_cpu_threads = 1;
89 TI_ASSERT(block_dim <= taichi_max_gpu_block_dim);
90 } else { // cpu
91 if (config.num_cpu_threads == 0) {
92 num_cpu_threads = std::thread::hardware_concurrency();
93 } else {
94 num_cpu_threads = config.num_cpu_threads;
95 }
96 }
97}
98
99void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) {
100 loop_var_ids.reserve(loop_vars.size());
101 for (int i = 0; i < (int)loop_vars.size(); i++) {
102 add_loop_var(loop_vars[i]);
103 }
104}
105
106void FrontendForStmt::add_loop_var(const Expr &loop_var) {
107 loop_var_ids.push_back(loop_var.cast<IdExpression>()->id);
108 loop_var.expr->ret_type = PrimitiveType::i32;
109}
110
111void ArgLoadExpression::type_check(const CompileConfig *) {
112 TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
113 "Invalid dt [{}] for ArgLoadExpression", dt->to_string());
114 ret_type = dt;
115}
116
117void ArgLoadExpression::flatten(FlattenContext *ctx) {
118 auto arg_load = std::make_unique<ArgLoadStmt>(arg_id, dt, is_ptr);
119 ctx->push_back(std::move(arg_load));
120 stmt = ctx->back_stmt();
121}
122
123void TexturePtrExpression::type_check(const CompileConfig *config) {
124}
125
126void TexturePtrExpression::flatten(FlattenContext *ctx) {
127 ctx->push_back<ArgLoadStmt>(arg_id, PrimitiveType::f32, true);
128 ctx->push_back<TexturePtrStmt>(ctx->back_stmt(), num_dims, is_storage,
129 num_channels, channel_format, lod);
130 stmt = ctx->back_stmt();
131}
132
133void RandExpression::type_check(const CompileConfig *) {
134 TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
135 "Invalid dt [{}] for RandExpression", dt->to_string());
136 ret_type = dt;
137}
138
139void RandExpression::flatten(FlattenContext *ctx) {
140 auto ran = std::make_unique<RandStmt>(dt);
141 ctx->push_back(std::move(ran));
142 stmt = ctx->back_stmt();
143}
144
145void UnaryOpExpression::type_check(const CompileConfig *config) {
146 TI_ASSERT_TYPE_CHECKED(operand);
147
148 TI_ASSERT(config != nullptr);
  /*
  Dtype inference for TensorType and PrimitiveType is essentially the same.
  Therefore we extract the primitive (element) type to perform the type
  inference, and then reconstruct the TensorType when necessary.
  */
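  // Illustrative examples of the inference performed below:
  //   neg  : TensorType<4 x i32> -> TensorType<4 x i32>
  //   sqrt : i32                 -> config->default_fp (e.g. f32)
  //   cast : element type        -> cast_type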
154
155 auto operand_primitive_type = operand->ret_type.get_element_type();
156 auto ret_primitive_type = ret_type;
157
158 if (!operand_primitive_type->is<PrimitiveType>()) {
159 throw TaichiTypeError(fmt::format(
160 "unsupported operand type(s) for '{}': '{}'", unary_op_type_name(type),
161 operand_primitive_type->to_string()));
162 }
163
164 if ((type == UnaryOpType::round || type == UnaryOpType::floor ||
165 type == UnaryOpType::ceil || is_trigonometric(type)) &&
166 !is_real(operand_primitive_type))
    throw TaichiTypeError(fmt::format(
        "'{}' takes real inputs only, but '{}' was provided",
        unary_op_type_name(type), operand_primitive_type->to_string()));
170
171 if ((type == UnaryOpType::sqrt || type == UnaryOpType::exp ||
172 type == UnaryOpType::log) &&
173 !is_real(operand_primitive_type)) {
174 ret_primitive_type = config->default_fp;
175 } else {
176 ret_primitive_type = is_cast() ? cast_type : operand_primitive_type;
177 }
178
179 if (operand->ret_type->is<TensorType>()) {
180 ret_type = taichi::lang::TypeFactory::get_instance().get_tensor_type(
181 operand->ret_type.get_shape(), ret_primitive_type);
182 } else {
183 TI_ASSERT(operand->ret_type->is<PrimitiveType>());
184 ret_type = ret_primitive_type;
185 }
186}
187
188bool UnaryOpExpression::is_cast() const {
189 return unary_op_is_cast(type);
190}
191
192void UnaryOpExpression::flatten(FlattenContext *ctx) {
193 auto operand_stmt = flatten_rvalue(operand, ctx);
194 auto unary = std::make_unique<UnaryOpStmt>(type, operand_stmt);
195 if (is_cast()) {
196 unary->cast_type = cast_type;
197 }
198 stmt = unary.get();
199 stmt->tb = tb;
200 stmt->ret_type = ret_type;
201 ctx->push_back(std::move(unary));
202}
203
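// Broadcasts `elt` to the tensor type `dt` when needed. Illustrative example:
// broadcasting a scalar f32 against a TensorType<3 x f32> yields a
// MatrixExpression of shape {3} whose elements all reference the scalar;
// element dtypes are promoted later during type checking.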
204Expr to_broadcast_tensor(const Expr &elt, const DataType &dt) {
205 if (!elt->ret_type->is<TensorType>() && !dt->is<TensorType>())
206 return elt;
207
208 if (elt->ret_type->is<TensorType>() && dt->is<TensorType>()) {
    // Only the tensor shapes are checked here, since the dtypes will be
    // promoted later in irpass::type_check().
211 if (elt->ret_type.get_shape() != dt.get_shape()) {
212 TI_ERROR("Cannot broadcast tensor to tensor");
213 } else {
214 return elt;
215 }
216 }
217
218 auto tensor_type = dt->as<TensorType>();
219 auto elt_type = tensor_type->get_element_type();
220 TI_ASSERT_INFO(elt_type->is<PrimitiveType>(),
221 "Only primitive types are supported in Tensors, got {}",
222 elt_type->to_string());
223 std::vector<Expr> broadcast_values(tensor_type->get_num_elements(), elt);
224 auto matrix_expr = Expr::make<MatrixExpression>(
225 broadcast_values, tensor_type->get_shape(), elt->ret_type);
226 matrix_expr->type_check(nullptr);
227 return matrix_expr;
228}
229
230std::tuple<Expr, Expr> unify_binop_operands(const Expr &e1, const Expr &e2) {
231 if (e1->ret_type->is<PrimitiveType>() && e2->ret_type->is<TensorType>()) {
232 return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2);
233 } else if (e1->ret_type->is<TensorType>() &&
234 e2->ret_type->is<PrimitiveType>()) {
235 return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type));
236 } else {
237 return std::tuple(e1, e2);
238 }
239}
240
241void BinaryOpExpression::type_check(const CompileConfig *config) {
242 TI_ASSERT_TYPE_CHECKED(lhs);
243 TI_ASSERT_TYPE_CHECKED(rhs);
244 auto lhs_type = lhs->ret_type;
245 auto rhs_type = rhs->ret_type;
246 auto error = [&]() {
247 throw TaichiTypeError(
248 fmt::format("unsupported operand type(s) for '{}': '{}' and '{}'",
249 binary_op_type_symbol(type), lhs->ret_type->to_string(),
250 rhs->ret_type->to_string()));
251 };
252
253 if (!is_primitive_or_tensor_type(lhs_type) ||
254 !is_primitive_or_tensor_type(rhs_type)) {
255 error();
256 }
257
258 if ((lhs_type->is<PrimitiveType>() && rhs_type->is<TensorType>()) ||
259 (lhs_type->is<TensorType>() && rhs_type->is<PrimitiveType>())) {
    // Convert Tensor-Scalar or Scalar-Tensor operations to broadcasting.
261 auto [unified_l, unified_r] = unify_binop_operands(lhs, rhs);
262 lhs = unified_l;
263 rhs = unified_r;
264 if (lhs->ret_type == PrimitiveType::unknown)
265 lhs.type_check(config);
266 if (rhs->ret_type == PrimitiveType::unknown)
267 rhs.type_check(config);
268 TI_ASSERT(lhs->ret_type->is<TensorType>());
269 TI_ASSERT(rhs->ret_type->is<TensorType>());
270 lhs_type = lhs->ret_type;
271 rhs_type = rhs->ret_type;
272 }
273
274 bool is_tensor_op = false;
275
276 if (lhs_type->is<TensorType>()) {
277 is_tensor_op = true;
278 auto rhs_tensor_type = rhs_type->cast<TensorType>();
279 if (rhs_tensor_type->get_shape() !=
280 lhs_type->cast<TensorType>()->get_shape())
      // Currently assumes element-wise binary ops.
282 error();
283 }
284
285 auto make_dt = [&is_tensor_op, this](DataType dt) {
286 if (is_tensor_op) {
287 return TypeFactory::create_tensor_type(
288 this->lhs->ret_type->cast<TensorType>()->get_shape(), dt);
289 } else {
290 return dt;
291 }
292 };
293
294 if (binary_is_bitwise(type) && (!is_integral(lhs_type.get_element_type()) ||
295 !is_integral(rhs_type.get_element_type())))
296 error();
297 if (binary_is_logical(type) &&
298 (is_tensor_op || lhs_type != PrimitiveType::i32 ||
299 rhs_type != PrimitiveType::i32))
300 error();
301 if (is_comparison(type) || binary_is_logical(type)) {
302 ret_type = make_dt(PrimitiveType::i32);
303 return;
304 }
305 if (is_shift_op(type) ||
306 (type == BinaryOpType::pow && is_integral(rhs_type))) {
307 ret_type = lhs_type;
308 return;
309 }
310
  // Some backends, such as Vulkan, do not support fp64.
  // Avoid promoting to fp64 unless necessary.
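  // E.g. atan2(f32, i32) -> f32 and atan2(f64, f32) -> f64 (illustrative).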
313 if (type == BinaryOpType::atan2) {
314 if (lhs_type == PrimitiveType::f64 || rhs_type == PrimitiveType::f64) {
315 ret_type = make_dt(PrimitiveType::f64);
316 } else {
317 ret_type = make_dt(PrimitiveType::f32);
318 }
319 return;
320 }
321
322 if (type == BinaryOpType::truediv) {
323 auto default_fp = config->default_fp;
324 if (!is_real(lhs_type.get_element_type())) {
325 lhs_type = make_dt(default_fp);
326 }
327 if (!is_real(rhs_type.get_element_type())) {
328 rhs_type = make_dt(default_fp);
329 }
330 }
331 ret_type = promoted_type(lhs_type, rhs_type);
332}
333
334void BinaryOpExpression::flatten(FlattenContext *ctx) {
335 // if (stmt)
336 // return;
337 auto lhs_stmt = flatten_rvalue(lhs, ctx);
338
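  // Logical and/or are lowered with short-circuit semantics, roughly
  // (pseudo-IR, illustrative):
  //   result = alloca(ret_type)
  //   result <- lhs
  //   if (load(result)) { ... }   // logical_and stores rhs in the true
  //                               // branch, logical_or in the false branch
  //   stmt = load(result)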
339 if (binary_is_logical(type)) {
340 auto result = ctx->push_back<AllocaStmt>(ret_type);
341 ctx->push_back<LocalStoreStmt>(result, lhs_stmt);
342 auto cond = ctx->push_back<LocalLoadStmt>(result);
343 auto if_stmt = ctx->push_back<IfStmt>(cond);
344
345 FlattenContext rctx;
346 rctx.current_block = ctx->current_block;
347 auto rhs_stmt = flatten_rvalue(rhs, &rctx);
348 rctx.push_back<LocalStoreStmt>(result, rhs_stmt);
349
350 auto true_block = std::make_unique<Block>();
351 if (type == BinaryOpType::logical_and) {
352 true_block->set_statements(std::move(rctx.stmts));
353 }
354 if_stmt->set_true_statements(std::move(true_block));
355
356 auto false_block = std::make_unique<Block>();
357 if (type == BinaryOpType::logical_or) {
358 false_block->set_statements(std::move(rctx.stmts));
359 }
360 if_stmt->set_false_statements(std::move(false_block));
361
362 auto ret = ctx->push_back<LocalLoadStmt>(result);
363 ret->tb = tb;
364 stmt = ret;
365 stmt->ret_type = ret_type;
366 return;
367 }
368 auto rhs_stmt = flatten_rvalue(rhs, ctx);
369 ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs_stmt, rhs_stmt));
370 ctx->stmts.back()->tb = tb;
371 stmt = ctx->back_stmt();
372 stmt->ret_type = ret_type;
373}
374
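// Lowers a functional if-then-else into an if/else writing to a temporary
// alloca, roughly (pseudo-IR, illustrative):
//   result = alloca(ret_type)
//   if (cond) { result <- true_val } else { result <- false_val }
//   load(result)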
375void make_ifte(Expression::FlattenContext *ctx,
376 DataType ret_type,
377 Expr cond,
378 Expr true_val,
379 Expr false_val) {
380 auto result = ctx->push_back<AllocaStmt>(ret_type);
381 auto cond_stmt = flatten_rvalue(cond, ctx);
382 auto if_stmt = ctx->push_back<IfStmt>(cond_stmt);
383
384 Expression::FlattenContext lctx;
385 lctx.current_block = ctx->current_block;
386 auto true_val_stmt = flatten_rvalue(true_val, &lctx);
387 lctx.push_back<LocalStoreStmt>(result, true_val_stmt);
388
389 Expression::FlattenContext rctx;
390 rctx.current_block = ctx->current_block;
391 auto false_val_stmt = flatten_rvalue(false_val, &rctx);
392 rctx.push_back<LocalStoreStmt>(result, false_val_stmt);
393
394 auto true_block = std::make_unique<Block>();
395 true_block->set_statements(std::move(lctx.stmts));
396 if_stmt->set_true_statements(std::move(true_block));
397
398 auto false_block = std::make_unique<Block>();
399 false_block->set_statements(std::move(rctx.stmts));
400 if_stmt->set_false_statements(std::move(false_block));
401
402 ctx->push_back<LocalLoadStmt>(result);
403 return;
404}
405
406static std::tuple<Expr, Expr, Expr> unify_ternaryop_operands(const Expr &e1,
407 const Expr &e2,
408 const Expr &e3) {
409 auto target_dtype = PrimitiveType::unknown;
410 // Since we don't support broadcasting between two TensorTypes,
411 // we can simply use the first TensorType's dtype as the target dtype.
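  // E.g. select(cond_i32, TensorType<3 x f32>, scalar_f32) broadcasts both the
  // condition and the scalar branch to 3-element tensors (illustrative).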
412 if (e1->ret_type->is<TensorType>()) {
413 target_dtype = e1->ret_type;
414 } else if (e2->ret_type->is<TensorType>()) {
415 target_dtype = e2->ret_type;
416 } else if (e3->ret_type->is<TensorType>()) {
417 target_dtype = e3->ret_type;
418 }
419
420 if (target_dtype == PrimitiveType::unknown) {
421 return std::tuple(e1, e2, e3);
422 }
423
424 return std::tuple(to_broadcast_tensor(e1, target_dtype),
425 to_broadcast_tensor(e2, target_dtype),
426 to_broadcast_tensor(e3, target_dtype));
427}
428
429void TernaryOpExpression::type_check(const CompileConfig *config) {
430 TI_ASSERT_TYPE_CHECKED(op1);
431 TI_ASSERT_TYPE_CHECKED(op2);
432 TI_ASSERT_TYPE_CHECKED(op3);
433
434 bool is_valid = true;
435 bool is_tensor = false;
436
437 auto [unified_cond, unified_l, unified_r] =
438 unify_ternaryop_operands(op1, op2, op3);
439 op1 = unified_cond;
440 op2 = unified_l;
441 op3 = unified_r;
442 auto op1_type = op1->ret_type;
443 auto op2_type = op2->ret_type;
444 auto op3_type = op3->ret_type;
445
446 auto error = [&]() {
447 throw TaichiTypeError(
448 fmt::format("unsupported operand type(s) for '{}': '{}', '{}' and '{}'",
449 ternary_type_name(type), op1->ret_type->to_string(),
450 op2->ret_type->to_string(), op3->ret_type->to_string()));
451 };
452
453 if (op1_type->is<TensorType>() && op2_type->is<TensorType>() &&
454 op3_type->is<TensorType>()) {
455 // valid
456 is_tensor = true;
457 if (op1_type->cast<TensorType>()->get_shape() !=
458 op2_type->cast<TensorType>()->get_shape()) {
459 is_valid = false;
460 }
461 if (op2_type->cast<TensorType>()->get_shape() !=
462 op3_type->cast<TensorType>()->get_shape()) {
463 is_valid = false;
464 }
465 op1_type = op1_type->cast<TensorType>()->get_element_type();
466 op2_type = op2_type->cast<TensorType>()->get_element_type();
467 op3_type = op3_type->cast<TensorType>()->get_element_type();
468
469 } else if (op1_type->is<PrimitiveType>() && op2_type->is<PrimitiveType>() &&
470 op3_type->is<PrimitiveType>()) {
471 // valid
472 } else {
473 is_valid = false;
474 }
475
476 if (op1_type != PrimitiveType::i32) {
477 is_valid = false;
478 }
479 if (!op2_type->is<PrimitiveType>() || !op3_type->is<PrimitiveType>()) {
480 is_valid = false;
481 }
482
483 if (!is_valid)
484 error();
485
486 if (is_tensor) {
487 auto primitive_dtype = promoted_type(op2_type, op3_type);
488 auto shape = op2->ret_type->cast<TensorType>()->get_shape();
489 ret_type = TypeFactory::create_tensor_type(shape, primitive_dtype);
490 } else {
491 ret_type = promoted_type(op2_type, op3_type);
492 }
493}
494
495void TernaryOpExpression::flatten(FlattenContext *ctx) {
496 // if (stmt)
497 // return;
498 if (type == TernaryOpType::select) {
499 auto op1_stmt = flatten_rvalue(op1, ctx);
500 auto op2_stmt = flatten_rvalue(op2, ctx);
501 auto op3_stmt = flatten_rvalue(op3, ctx);
502 ctx->push_back(
503 std::make_unique<TernaryOpStmt>(type, op1_stmt, op2_stmt, op3_stmt));
504 } else if (type == TernaryOpType::ifte) {
505 make_ifte(ctx, ret_type, op1, op2, op3);
506 }
507 stmt = ctx->back_stmt();
508 stmt->tb = tb;
509 stmt->ret_type = ret_type;
510}
511
512void InternalFuncCallExpression::type_check(const CompileConfig *) {
513 for (auto &arg : args) {
514 TI_ASSERT_TYPE_CHECKED(arg);
515 // no arg type compatibility check for now due to lack of specification
516 }
517 // internal func calls have default return type
518 ret_type = PrimitiveType::i32;
519}
520
521void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
522 std::vector<Stmt *> args_stmts(args.size());
523 for (int i = 0; i < (int)args.size(); ++i) {
524 args_stmts[i] = flatten_rvalue(args[i], ctx);
525 }
526 ctx->push_back<InternalFuncStmt>(func_name, args_stmts, nullptr,
527 with_runtime_context);
528 stmt = ctx->back_stmt();
529 stmt->tb = tb;
530}
531
532void ExternalTensorExpression::flatten(FlattenContext *ctx) {
533 // https://github.com/taichi-dev/taichi/issues/5819
  // ArgLoadStmt keeps primitive types since all matrix types get
  // scalarized at Python scope.
  //
  // FIXME(zhanlue): ArgLoadStmt should use TensorType once real_matrix is
  // turned on by default. The scalarization should then happen after
  // irpass::lower_access().
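  //
  // Illustrative example: for a 2D ndarray of 3x3 matrices, dim == 4 and
  // |element_dim| == 2 (the sign encodes the layout), so external_dims
  // below is 4 - 2 = 2.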
541 auto prim_dt = dt;
542 auto ptr = Stmt::make<ArgLoadStmt>(arg_id, prim_dt, /*is_ptr=*/true,
543 /*is_grad=*/is_grad);
544
545 int external_dims = dim - std::abs(element_dim);
546 ptr->cast<ArgLoadStmt>()->set_extern_dims(external_dims);
547
548 ptr->tb = tb;
549 ctx->push_back(std::move(ptr));
550 stmt = ctx->back_stmt();
551}
552
553std::vector<Stmt *> make_index_stmts(Expression::FlattenContext *ctx,
554 const ExprGroup &indices,
555 const std::vector<int> &offsets) {
556 std::vector<Stmt *> index_stmts;
557 for (int i = 0; i < (int)indices.size(); i++) {
558 Stmt *ind = flatten_rvalue(indices.exprs[i], ctx);
559 if (!offsets.empty()) {
560 auto offset = ctx->push_back<ConstStmt>(TypedConstant(offsets[i]));
561 ind = ctx->push_back<BinaryOpStmt>(BinaryOpType::sub, ind, offset);
562 }
563 index_stmts.push_back(ind);
564 }
565 return index_stmts;
566}
567
568Stmt *make_field_access(Expression::FlattenContext *ctx,
569 const FieldExpression &field,
570 ExprGroup indices) {
571 return ctx->push_back(std::make_unique<GlobalPtrStmt>(
572 field.snode, make_index_stmts(ctx, indices, field.snode->index_offsets)));
573}
574
575Stmt *make_matrix_field_access(Expression::FlattenContext *ctx,
576 const MatrixFieldExpression &matrix_field,
577 ExprGroup indices,
578 DataType ret_type) {
579 std::vector<SNode *> snodes;
580 for (auto &field : matrix_field.fields) {
581 snodes.push_back(field.cast<FieldExpression>()->snode);
582 }
583 return ctx->push_back(std::make_unique<MatrixOfGlobalPtrStmt>(
584 snodes, make_index_stmts(ctx, indices, snodes[0]->index_offsets),
585 matrix_field.dynamic_indexable, matrix_field.dynamic_index_stride,
586 ret_type));
587}
588
589Stmt *make_ndarray_access(Expression::FlattenContext *ctx,
590 Expr var,
591 ExprGroup indices) {
592 std::vector<Stmt *> index_stmts;
593 for (int i = 0; i < (int)indices.size(); i++) {
594 Stmt *ind = flatten_rvalue(indices.exprs[i], ctx);
595 index_stmts.push_back(ind);
596 }
597 auto var_stmt = flatten_lvalue(var, ctx);
598 auto expr = var.cast<ExternalTensorExpression>();
599 auto external_ptr_stmt = std::make_unique<ExternalPtrStmt>(
600 var_stmt, index_stmts, expr->dt.get_shape(), expr->element_dim);
601 if (expr->dim == indices.size()) {
    // Indexing into a scalar element
603 external_ptr_stmt->ret_type = expr->dt.ptr_removed().get_element_type();
604 } else {
605 // Indexing outer dimensions
606 external_ptr_stmt->ret_type = expr->dt.ptr_removed();
607 }
608
609 return ctx->push_back(std::move(external_ptr_stmt));
610}
611
612Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx,
613 Stmt *var_stmt,
614 const ExprGroup &indices,
615 const std::vector<int> &shape,
616 const std::string &tb) {
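  // Linearizes a (row-major) multi-dimensional index into a flat offset.
  // E.g. for shape = {3, 4} and indices (i, j) the offset is i * 4 + j;
  // when every index is a constant, the offset is emitted as a single
  // ConstStmt (illustrative).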
617 bool needs_dynamic_index = false;
618 for (int i = 0; i < (int)indices.size(); ++i) {
619 if (!indices[i].is<ConstExpression>()) {
620 needs_dynamic_index = true;
621 }
622 }
623 Stmt *offset_stmt = nullptr;
624 if (needs_dynamic_index) {
625 offset_stmt = ctx->push_back<ConstStmt>(TypedConstant(0));
626 for (int i = 0; i < (int)indices.size(); ++i) {
627 auto index_stmt = flatten_rvalue(indices[i], ctx);
628 Stmt *shape_stmt = ctx->push_back<ConstStmt>(TypedConstant(shape[i]));
629 Stmt *mul_stmt = ctx->push_back<BinaryOpStmt>(BinaryOpType::mul,
630 offset_stmt, shape_stmt);
631 offset_stmt =
632 ctx->push_back<BinaryOpStmt>(BinaryOpType::add, mul_stmt, index_stmt);
633 }
634 } else {
635 int offset = 0;
636 for (int i = 0; i < (int)indices.size(); ++i) {
637 offset =
638 offset * shape[i] + indices[i].cast<ConstExpression>()->val.val_int();
639 }
640 offset_stmt = ctx->push_back<ConstStmt>(TypedConstant(offset));
641 }
642 return ctx->push_back<MatrixPtrStmt>(var_stmt, offset_stmt, tb);
643}
644
645Stmt *make_tensor_access(Expression::FlattenContext *ctx,
646 Expr var,
647 const std::vector<ExprGroup> &indices_group,
648 DataType ret_type,
649 std::vector<int> shape,
650 const std::string &tb) {
651 auto var_stmt = flatten_lvalue(var, ctx);
652 if (!var->is_lvalue()) {
653 auto alloca_stmt = ctx->push_back<AllocaStmt>(var->ret_type);
654 ctx->push_back<LocalStoreStmt>(alloca_stmt, var_stmt);
655 var_stmt = alloca_stmt;
656 }
657 if (is_tensor(ret_type)) {
658 std::vector<Stmt *> stmts;
659 for (auto &indices : indices_group) {
660 stmts.push_back(
661 make_tensor_access_single_element(ctx, var_stmt, indices, shape, tb));
662 }
663 return ctx->push_back<MatrixOfMatrixPtrStmt>(stmts, ret_type);
664 }
665 return make_tensor_access_single_element(ctx, var_stmt, indices_group[0],
666 shape, tb);
667}
668
669void MatrixExpression::type_check(const CompileConfig *config) {
670 TI_ASSERT(dt->as<TensorType>()->get_num_elements() == elements.size());
671
672 for (auto &arg : elements) {
673 TI_ASSERT_TYPE_CHECKED(arg);
674 if (arg->ret_type != dt.get_element_type()) {
675 arg = cast(arg, dt.get_element_type());
676 arg->type_check(config);
677 }
678 }
679 ret_type = dt;
680}
681
682void MatrixExpression::flatten(FlattenContext *ctx) {
683 TI_ASSERT(this->dt->is<TensorType>());
684 std::vector<Stmt *> values;
685 for (auto &elt : elements) {
686 values.push_back(flatten_rvalue(elt, ctx));
687 }
688 stmt = ctx->push_back<MatrixInitStmt>(values);
689 stmt->ret_type = this->dt;
690}
691
692IndexExpression::IndexExpression(const Expr &var,
693 const ExprGroup &indices,
694 std::string tb)
695 : var(var), indices_group({indices}) {
696 this->tb = tb;
697}
698
699IndexExpression::IndexExpression(const Expr &var,
700 const std::vector<ExprGroup> &indices_group,
701 const std::vector<int> &ret_shape,
702 std::string tb)
703 : var(var), indices_group(indices_group), ret_shape(ret_shape) {
  // An IndexExpression with ret_shape is used for matrix slicing, where each
  // entry of ExprGroup is interpreted as a group of indices to return within
  // each axis. For example, mat[0, 3:5] has indices_group={0, [3, 4]}, where
  // [3, 4] means the "m"-axis will return a TensorType of size 2. In this
  // case, we should not expand indices_group due to its special semantics.
709 this->tb = tb;
710}
711
712bool IndexExpression::is_field() const {
713 return var.is<FieldExpression>();
714}
715
716bool IndexExpression::is_matrix_field() const {
717 return var.is<MatrixFieldExpression>();
718}
719
720bool IndexExpression::is_ndarray() const {
721 return var.is<ExternalTensorExpression>();
722}
723
724bool IndexExpression::is_tensor() const {
725 return var->ret_type->is<TensorType>();
726}
727
728bool IndexExpression::is_local() const {
729 return !is_global();
730}
731
732bool IndexExpression::is_global() const {
  // Special case: indexing into a TensorType element of an ExternalPtrStmt
  // or GlobalPtrStmt should be treated as a global pointer.
735 if (var.is<IndexExpression>()) {
736 TI_ASSERT(var.cast<IndexExpression>()->is_matrix_field() ||
737 var.cast<IndexExpression>()->is_ndarray());
738 return true;
739 }
740
  // Only ndarrays and fields come from outside the kernel.
742 return is_field() || is_matrix_field() || is_ndarray();
743}
744
745static void field_validation(FieldExpression *field_expr, int index_dim) {
746 TI_ASSERT(field_expr != nullptr);
747 TI_ASSERT(field_expr->snode != nullptr);
748 int field_dim = field_expr->snode->num_active_indices;
749
750 if (field_dim != index_dim) {
751 throw TaichiIndexError(
752 fmt::format("Field with dim {} accessed with indices of dim {}",
753 field_dim, index_dim));
754 }
755}
756
757void IndexExpression::type_check(const CompileConfig *) {
  // TODO: Change to a type-based solution.
  // Currently, the dimension compatibility check happens in Python.
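  // For matrix slicing, the number of index groups must equal the number of
  // elements in the returned shape, e.g. mat[1:3, 0:2] has ret_shape {2, 2}
  // and thus 4 index groups; plain indexing uses a single group
  // (illustrative).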
760 TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape),
761 end(ret_shape), 1,
762 std::multiplies<>()));
763 int index_dim = indices_group.empty() ? 0 : indices_group[0].size();
764 bool has_slice = !ret_shape.empty();
765 if (has_slice) {
    TI_ASSERT_INFO(is_tensor(),
                   "Slice or swizzle can only be applied to matrices");
767 auto element_type = var->ret_type->as<TensorType>()->get_element_type();
768 ret_type = TypeFactory::create_tensor_type(ret_shape, element_type);
769
770 } else if (is_field()) { // field
771 auto field_expr = var.cast<FieldExpression>();
772 field_validation(field_expr.get(), index_dim);
773 ret_type = field_expr->dt->get_compute_type();
774
775 } else if (is_matrix_field()) {
776 auto matrix_field_expr = var.cast<MatrixFieldExpression>();
777
778 TI_ASSERT(!matrix_field_expr->fields.empty());
779 auto field_expr = matrix_field_expr->fields[0].cast<FieldExpression>();
780 field_validation(field_expr.get(), index_dim);
781
782 ret_type = TypeFactory::create_tensor_type(matrix_field_expr->element_shape,
783 matrix_field_expr->fields[0]
784 .cast<FieldExpression>()
785 ->dt->get_compute_type());
786 } else if (is_ndarray()) { // ndarray
787 auto external_tensor_expr = var.cast<ExternalTensorExpression>();
788 int total_dim = external_tensor_expr->dim;
789 int element_dim = external_tensor_expr->dt.get_shape().size();
790 if (total_dim != index_dim + element_dim) {
791 throw TaichiTypeError(
792 fmt::format("Array with dim {} accessed with indices of dim {}",
793 total_dim - element_dim, index_dim));
794 }
795
796 if (index_dim == total_dim) {
797 // Access all the way to a single element
798 ret_type = var.cast<ExternalTensorExpression>()->dt.get_element_type();
799 } else {
800 // Access to a Tensor
801 ret_type = var.cast<ExternalTensorExpression>()->dt;
802 }
803 } else if (is_tensor()) { // local tensor
804 auto shape = var->ret_type->as<TensorType>()->get_shape();
805 if (indices_group[0].size() != shape.size()) {
806 TI_ERROR("Expected {} indices, got {}.", shape.size(),
807 indices_group[0].size());
808 }
809 ret_type = var->ret_type->cast<TensorType>()->get_element_type();
810 } else {
811 throw TaichiTypeError(
812 "Invalid IndexExpression: the source is not among field, ndarray or "
813 "local tensor");
814 }
815
816 for (auto &indices : indices_group) {
817 for (int i = 0; i < indices.exprs.size(); i++) {
818 auto &expr = indices.exprs[i];
819 TI_ASSERT_TYPE_CHECKED(expr);
820 if (!is_integral(expr->ret_type))
821 throw TaichiTypeError(
            fmt::format("indices must be integers, but '{}' was "
                        "provided as index {}",
                        expr->ret_type->to_string(), i));
825 }
826 }
827}
828
829void IndexExpression::flatten(FlattenContext *ctx) {
830 if (is_field()) {
831 stmt =
832 make_field_access(ctx, *var.cast<FieldExpression>(), indices_group[0]);
833 } else if (is_matrix_field()) {
834 stmt = make_matrix_field_access(ctx, *var.cast<MatrixFieldExpression>(),
835 indices_group[0], ret_type);
836 } else if (is_ndarray()) {
837 stmt = make_ndarray_access(ctx, var, indices_group[0]);
838 } else if (is_tensor()) {
839 stmt =
840 make_tensor_access(ctx, var, indices_group, ret_type,
841 var->ret_type->cast<TensorType>()->get_shape(), tb);
842 } else {
843 throw TaichiTypeError(
844 "Invalid IndexExpression: the source is not among field, ndarray or "
845 "local tensor");
846 }
847 stmt->tb = tb;
848}
849
850void RangeAssumptionExpression::type_check(const CompileConfig *) {
851 TI_ASSERT_TYPE_CHECKED(input);
852 TI_ASSERT_TYPE_CHECKED(base);
853 if (!input->ret_type->is<PrimitiveType>() ||
854 !base->ret_type->is<PrimitiveType>() || input->ret_type != base->ret_type)
855 throw TaichiTypeError(
856 fmt::format("unsupported operand type(s) for "
857 "'range_assumption': '{}' and '{}'",
858 input->ret_type->to_string(), base->ret_type->to_string()));
859 ret_type = input->ret_type;
860}
861
862void RangeAssumptionExpression::flatten(FlattenContext *ctx) {
863 auto input_stmt = flatten_rvalue(input, ctx);
864 auto base_stmt = flatten_rvalue(base, ctx);
865 ctx->push_back(
866 Stmt::make<RangeAssumptionStmt>(input_stmt, base_stmt, low, high));
867 stmt = ctx->back_stmt();
868}
869
870void LoopUniqueExpression::type_check(const CompileConfig *) {
871 TI_ASSERT_TYPE_CHECKED(input);
872 if (!input->ret_type->is<PrimitiveType>())
873 throw TaichiTypeError(
874 fmt::format("unsupported operand type(s) for 'loop_unique': '{}'",
875 input->ret_type->to_string()));
876 ret_type = input->ret_type;
877}
878
879void LoopUniqueExpression::flatten(FlattenContext *ctx) {
880 auto input_stmt = flatten_rvalue(input, ctx);
881 ctx->push_back(Stmt::make<LoopUniqueStmt>(input_stmt, covers));
882 stmt = ctx->back_stmt();
883}
884
885void IdExpression::flatten(FlattenContext *ctx) {
886 stmt = ctx->current_block->lookup_var(id);
887 if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) {
888 stmt->ret_type = ret_type;
889 }
890}
891
892void AtomicOpExpression::type_check(const CompileConfig *config) {
893 TI_ASSERT_TYPE_CHECKED(dest);
894 TI_ASSERT_TYPE_CHECKED(val);
895 auto error = [&]() {
896 throw TaichiTypeError(fmt::format(
897 "unsupported operand type(s) for 'atomic_{}': '{}' and '{}'",
898 atomic_op_type_name(op_type), dest->ret_type->to_string(),
899 val->ret_type->to_string()));
900 };
901
  // Broadcast val to dest if necessary
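  // E.g. atomically adding a scalar into a TensorType<3 x f32> destination
  // first broadcasts the scalar to a 3-element tensor (illustrative).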
903 auto val_dtype = val->ret_type;
904 auto dest_dtype = dest->ret_type.ptr_removed();
905 if (dest_dtype->is<PrimitiveType>() && val_dtype->is<TensorType>()) {
906 error();
907 }
908
909 if (val_dtype->is<PrimitiveType>() && dest_dtype->is<TensorType>()) {
910 auto broadcasted_expr = to_broadcast_tensor(val, dest_dtype);
911 val = std::move(broadcasted_expr);
912 val.type_check(config);
913 }
914
915 // Validate dtype
916 auto dtype = val->ret_type;
917 if (dtype->is<TensorType>()) {
918 dtype = dtype.get_element_type();
919 }
920
921 if (!dtype->is<PrimitiveType>()) {
922 error();
923 }
924
925 if (is_quant(dest->ret_type)) {
926 ret_type = dest->ret_type->get_compute_type();
927 } else if (dest->ret_type->is<PrimitiveType>() ||
928 dest->ret_type->is<TensorType>()) {
929 ret_type = dest->ret_type;
930 } else {
931 error();
932 }
933}
934
935void AtomicOpExpression::flatten(FlattenContext *ctx) {
936 TI_ASSERT(
937 dest.is<IdExpression>() || dest.is<IndexExpression>() ||
938 (dest.is<ArgLoadExpression>() && dest.cast<ArgLoadExpression>()->is_ptr));
  // Replace atomic sub with an atomic add of the negated value
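  // i.e. atomic_sub(dest, val) becomes atomic_add(dest, -val), casting val to
  // ret_type first if the types differ (illustrative).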
940 if (op_type == AtomicOpType::sub) {
941 if (val->ret_type != ret_type) {
942 val.set(Expr::make<UnaryOpExpression>(UnaryOpType::cast_value, val,
943 ret_type));
944 }
945
946 val.set(Expr::make<UnaryOpExpression>(UnaryOpType::neg, val));
947 op_type = AtomicOpType::add;
948 }
949 // expand rhs
950 auto val_stmt = flatten_rvalue(val, ctx);
951 auto dest_stmt = flatten_lvalue(dest, ctx);
952 stmt = ctx->push_back<AtomicOpStmt>(op_type, dest_stmt, val_stmt);
953 stmt->ret_type = stmt->as<AtomicOpStmt>()->dest->ret_type;
954 stmt->tb = tb;
955}
956
957SNodeOpExpression::SNodeOpExpression(SNode *snode,
958 SNodeOpType op_type,
959 const ExprGroup &indices)
960 : snode(snode), op_type(op_type), indices(indices) {
961}
962
963SNodeOpExpression::SNodeOpExpression(SNode *snode,
964 SNodeOpType op_type,
965 const ExprGroup &indices,
966 const std::vector<Expr> &values)
967 : SNodeOpExpression(snode, op_type, indices) {
968 this->values = values;
969}
970
971void SNodeOpExpression::type_check(const CompileConfig *config) {
972 if (op_type == SNodeOpType::get_addr) {
973 ret_type = PrimitiveType::u64;
974 } else {
975 ret_type = PrimitiveType::i32;
976 }
977 if (op_type == SNodeOpType::append) {
978 TI_ASSERT(snode->ch.size() == values.size());
979 for (int i = 0; i < values.size(); i++) {
980 TI_ASSERT_TYPE_CHECKED(values[i]);
981 auto &dst_type = snode->ch[i]->dt;
982 auto promoted = promoted_type(dst_type, values[i]->ret_type);
983 if (dst_type != promoted) {
984 TI_WARN("Append may lose precision: {} <- {}\n{}",
985 dst_type->to_string(), values[i]->ret_type->to_string(), tb);
986 }
987 values[i] = cast(values[i], dst_type);
988 values[i]->type_check(config);
989 }
990 }
991}
992
993void SNodeOpExpression::flatten(FlattenContext *ctx) {
994 std::vector<Stmt *> indices_stmt;
995 for (int i = 0; i < (int)indices.size(); i++) {
996 indices_stmt.push_back(flatten_rvalue(indices[i], ctx));
997 }
998 auto is_cell_access = SNodeOpStmt::activation_related(op_type) &&
999 snode->type != SNodeType::dynamic;
1000 auto ptr =
1001 ctx->push_back<GlobalPtrStmt>(snode, indices_stmt, true, is_cell_access);
1002 ptr->tb = tb;
1003 if (op_type == SNodeOpType::is_active) {
1004 TI_ERROR_IF(snode->type != SNodeType::pointer &&
1005 snode->type != SNodeType::hash &&
1006 snode->type != SNodeType::bitmasked,
1007 "ti.is_active only works on pointer, hash or bitmasked nodes.");
1008 ctx->push_back<SNodeOpStmt>(SNodeOpType::is_active, snode, ptr, nullptr);
1009 } else if (op_type == SNodeOpType::length) {
1010 ctx->push_back<SNodeOpStmt>(SNodeOpType::length, snode, ptr, nullptr);
1011 } else if (op_type == SNodeOpType::get_addr) {
1012 ctx->push_back<SNodeOpStmt>(SNodeOpType::get_addr, snode, ptr, nullptr);
1013 } else if (op_type == SNodeOpType::append) {
1014 auto alloca = ctx->push_back<AllocaStmt>(PrimitiveType::i32);
1015 alloca->set_tb(tb);
1016 auto addr =
1017 ctx->push_back<SNodeOpStmt>(SNodeOpType::allocate, snode, ptr, alloca);
1018 addr->set_tb(tb);
1019 for (int i = 0; i < values.size(); i++) {
1020 auto value_stmt = flatten_rvalue(values[i], ctx);
1021 auto ch_addr = ctx->push_back<GetChStmt>(addr, snode, i);
1022 ch_addr->set_tb(tb);
1023 ctx->push_back<GlobalStoreStmt>(ch_addr, value_stmt)->set_tb(tb);
1024 }
1025 ctx->push_back<LocalLoadStmt>(alloca)->set_tb(tb);
1026 TI_ERROR_IF(snode->type != SNodeType::dynamic,
1027 "ti.append only works on dynamic nodes.");
1028 }
1029 stmt = ctx->back_stmt();
1030}
1031
1032TextureOpExpression::TextureOpExpression(TextureOpType op,
1033 Expr texture_ptr,
1034 const ExprGroup &args)
1035 : op(op), texture_ptr(texture_ptr), args(args) {
1036}
1037
1038void TextureOpExpression::type_check(const CompileConfig *config) {
1039 TI_ASSERT(texture_ptr.is<TexturePtrExpression>());
1040 auto ptr = texture_ptr.cast<TexturePtrExpression>();
1041 if (op == TextureOpType::kSampleLod) {
1042 // UV, Lod
1043 TI_ASSERT_INFO(args.size() == ptr->num_dims + 1,
1044 "Invalid number of args for sample_lod Texture op with a "
1045 "{}-dimension texture",
1046 ptr->num_dims);
1047 for (int i = 0; i < ptr->num_dims; i++) {
1048 TI_ASSERT_TYPE_CHECKED(args[i]);
1049 if (args[i].get_ret_type() != PrimitiveType::f32) {
1050 throw TaichiTypeError(
1051 fmt::format("Invalid type for texture sample_lod: '{}', all "
1052 "arguments must be f32",
1053 args[i].get_ret_type()->to_string()));
1054 }
1055 }
1056 } else if (op == TextureOpType::kFetchTexel) {
1057 // index, int LOD
1058 TI_ASSERT_INFO(args.size() == ptr->num_dims + 1,
1059 "Invalid number of args for fetch_texel Texture op with a "
1060 "{}-dimension texture",
1061 ptr->num_dims);
1062 for (int i = 0; i < ptr->num_dims; i++) {
1063 TI_ASSERT_TYPE_CHECKED(args[i]);
1064 if (args[i].get_ret_type() != PrimitiveType::i32) {
1065 throw TaichiTypeError(
1066 fmt::format("Invalid type for texture fetch_texel: '{}', all "
1067 "arguments must be i32",
1068 args[i].get_ret_type()->to_string()));
1069 }
1070 }
1071 } else if (op == TextureOpType::kLoad) {
1072 // index
1073 TI_ASSERT_INFO(args.size() == ptr->num_dims,
1074 "Invalid number of args for load Texture op with a "
1075 "{}-dimension texture",
1076 ptr->num_dims);
1077 for (int i = 0; i < ptr->num_dims; i++) {
1078 TI_ASSERT_TYPE_CHECKED(args[i]);
1079 if (args[i].get_ret_type() != PrimitiveType::i32) {
1080 throw TaichiTypeError(
1081 fmt::format("Invalid type for texture load: '{}', all "
1082 "arguments must be i32",
1083 args[i].get_ret_type()->to_string()));
1084 }
1085 }
1086 } else if (op == TextureOpType::kStore) {
1087 // index, value
1088 TI_ASSERT_INFO(args.size() == ptr->num_dims + 4,
1089 "Invalid number of args for store Texture op with a "
1090 "{}-dimension texture",
1091 ptr->num_dims);
1092 for (int i = 0; i < ptr->num_dims; i++) {
1093 TI_ASSERT_TYPE_CHECKED(args[i]);
1094 if (args[i].get_ret_type() != PrimitiveType::i32) {
1095 throw TaichiTypeError(
            fmt::format("Invalid type for texture store: '{}', index "
1097 "arguments must be i32",
1098 args[i].get_ret_type()->to_string()));
1099 }
1100 }
1101 for (int i = ptr->num_dims; i < ptr->num_dims + 4; i++) {
1102 TI_ASSERT_TYPE_CHECKED(args[i]);
1103 if (args[i].get_ret_type() != PrimitiveType::f32) {
1104 throw TaichiTypeError(
            fmt::format("Invalid type for texture store: '{}', value "
1106 "arguments must be f32",
1107 args[i].get_ret_type()->to_string()));
1108 }
1109 }
1110 } else {
1111 TI_ERROR("Invalid TextureOpType");
1112 }
1113 ret_type =
1114 TypeFactory::get_instance().get_pointer_type(PrimitiveType::f32,
1115 /*is_bit_pointer=*/false);
1116}
1117
1118void TextureOpExpression::flatten(FlattenContext *ctx) {
1119 auto texture_ptr_stmt = flatten_rvalue(texture_ptr, ctx);
1120 std::vector<Stmt *> arg_stmts;
1121 for (Expr &arg : args.exprs) {
1122 arg_stmts.push_back(flatten_rvalue(arg, ctx));
1123 }
1124 ctx->push_back<TextureOpStmt>(op, texture_ptr_stmt, arg_stmts);
1125 stmt = ctx->back_stmt();
1126}
1127
1128void ConstExpression::type_check(const CompileConfig *) {
1129 TI_ASSERT_INFO(
1130 val.dt->is<PrimitiveType>() && val.dt != PrimitiveType::unknown,
1131 "Invalid dt [{}] for ConstExpression", val.dt->to_string());
1132 ret_type = val.dt;
1133}
1134
1135void ConstExpression::flatten(FlattenContext *ctx) {
1136 ctx->push_back(Stmt::make<ConstStmt>(val));
1137 stmt = ctx->back_stmt();
1138}
1139
1140void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) {
1141 TI_ASSERT_INFO(
1142 ptr.is<ExternalTensorExpression>() || ptr.is<TexturePtrExpression>(),
1143 "Invalid ptr [{}] for ExternalTensorShapeAlongAxisExpression",
1144 ExpressionHumanFriendlyPrinter::expr_to_string(ptr));
1145 ret_type = PrimitiveType::i32;
1146}
1147
1148void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
1149 auto temp = ptr.cast<ExternalTensorExpression>();
1150 TI_ASSERT(0 <= axis && axis < temp->dim);
1151 ctx->push_back<ExternalTensorShapeAlongAxisStmt>(axis, temp->arg_id);
1152 stmt = ctx->back_stmt();
1153}
1154
1155void GetElementExpression::type_check(const CompileConfig *config) {
1156 TI_ASSERT_TYPE_CHECKED(src);
1157
1158 ret_type = src->ret_type->as<StructType>()->get_element_type(index);
1159}
1160
1161void GetElementExpression::flatten(FlattenContext *ctx) {
1162 ctx->push_back<GetElementStmt>(flatten_rvalue(src, ctx), index);
1163 stmt = ctx->back_stmt();
1164}
1165// Mesh related.
1166
1167void MeshPatchIndexExpression::flatten(FlattenContext *ctx) {
1168 auto pid_stmt = std::make_unique<MeshPatchIndexStmt>();
1169 ctx->push_back(std::move(pid_stmt));
1170 stmt = ctx->back_stmt();
1171}
1172
1173void MeshPatchIndexExpression::type_check(const CompileConfig *) {
1174 ret_type = PrimitiveType::i32;
1175}
1176
1177void MeshRelationAccessExpression::type_check(const CompileConfig *) {
1178 ret_type = PrimitiveType::i32;
1179}
1180
1181void MeshRelationAccessExpression::flatten(FlattenContext *ctx) {
1182 auto mesh_idx_stmt = flatten_rvalue(mesh_idx, ctx);
1183 if (neighbor_idx) {
1184 auto neighbor_idx_stmt = flatten_rvalue(neighbor_idx, ctx);
1185 ctx->push_back<MeshRelationAccessStmt>(mesh, mesh_idx_stmt, to_type,
1186 neighbor_idx_stmt);
1187 } else {
1188 ctx->push_back<MeshRelationAccessStmt>(mesh, mesh_idx_stmt, to_type);
1189 }
1190 stmt = ctx->back_stmt();
1191}
1192
1193MeshIndexConversionExpression::MeshIndexConversionExpression(
1194 mesh::Mesh *mesh,
1195 mesh::MeshElementType idx_type,
1196 const Expr idx,
1197 mesh::ConvType conv_type)
1198 : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) {
1199}
1200
1201void MeshIndexConversionExpression::type_check(const CompileConfig *) {
1202 ret_type = PrimitiveType::i32;
1203}
1204
1205void MeshIndexConversionExpression::flatten(FlattenContext *ctx) {
1206 auto idx_stmt = flatten_rvalue(idx, ctx);
1207 ctx->push_back<MeshIndexConversionStmt>(mesh, idx_type, idx_stmt, conv_type);
1208 stmt = ctx->back_stmt();
1209}
1210
1211void ReferenceExpression::type_check(const CompileConfig *) {
1212 ret_type = var->ret_type;
1213}
1214
1215void ReferenceExpression::flatten(FlattenContext *ctx) {
1216 auto var_stmt = flatten_lvalue(var, ctx);
1217 ctx->push_back<ReferenceStmt>(var_stmt);
1218 stmt = ctx->back_stmt();
1219}
1220
1221Block *ASTBuilder::current_block() {
1222 if (stack_.empty())
1223 return nullptr;
1224 else
1225 return stack_.back();
1226}
1227
1228Stmt *ASTBuilder::get_last_stmt() {
1229 TI_ASSERT(!stack_.empty());
1230 return stack_.back()->back();
1231}
1232
1233void ASTBuilder::insert(std::unique_ptr<Stmt> &&stmt, int location) {
1234 TI_ASSERT(!stack_.empty());
1235 stack_.back()->insert(std::move(stmt), location);
1236}
1237
1238void ASTBuilder::stop_gradient(SNode *snode) {
1239 TI_ASSERT(!stack_.empty());
1240 stack_.back()->stop_gradients.push_back(snode);
1241}
1242
1243void ASTBuilder::insert_assignment(Expr &lhs,
1244 const Expr &rhs,
1245 const std::string &tb) {
1246 // Inside a kernel or a function
1247 // Create an assignment in the IR
1248 if (lhs.expr == nullptr) {
1249 lhs.set(rhs);
1250 } else if (lhs.expr->is_lvalue()) {
1251 auto stmt = std::make_unique<FrontendAssignStmt>(lhs, rhs);
1252 stmt->tb = tb;
1253 this->insert(std::move(stmt));
1254
1255 } else {
1256 TI_ERROR("Cannot assign to non-lvalue: {}",
1257 ExpressionHumanFriendlyPrinter::expr_to_string(lhs));
1258 }
1259}
1260
1261Expr ASTBuilder::make_var(const Expr &x, std::string tb) {
1262 auto var = this->expr_alloca();
1263 this->insert_assignment(var, x, tb);
1264 return var;
1265}
1266
1267Expr ASTBuilder::make_id_expr(const std::string &name) {
1268 return Expr::make<IdExpression>(get_next_id(name));
1269}
1270
1271void ASTBuilder::insert_for(const Expr &s,
1272 const Expr &e,
1273 const std::function<void(Expr)> &func) {
1274 auto i = Expr(std::make_shared<IdExpression>(get_next_id()));
1275 auto stmt_unique = std::make_unique<FrontendForStmt>(i, s, e, this->arch_,
1276 for_loop_dec_.config);
1277 for_loop_dec_.reset();
1278 auto stmt = stmt_unique.get();
1279 this->insert(std::move(stmt_unique));
1280 this->create_scope(stmt->body);
1281 func(i);
1282 this->pop_scope();
1283}
1284
1285Expr ASTBuilder::insert_thread_idx_expr() {
1286 auto loop = stack_.size() ? stack_.back()->parent_stmt : nullptr;
1287 TI_ERROR_IF(
1288 arch_ != Arch::cuda && !arch_is_cpu(arch_) && arch_ != Arch::amdgpu,
1289 "ti.thread_idx() is only available in cuda or cpu or amdgpu context.");
1290 if (loop != nullptr) {
1291 auto i = stack_.size() - 1;
1292 while (!(loop->is<FrontendForStmt>())) {
1293 loop = i > 0 ? stack_[--i]->parent_stmt : nullptr;
1294 if (loop == nullptr)
1295 break;
1296 }
1297 }
1298 TI_ERROR_IF(!(loop && loop->is<FrontendForStmt>()),
1299 "ti.thread_idx() is only valid within loops.");
1300 return Expr::make<InternalFuncCallExpression>(
1301 "linear_thread_idx", std::vector<Expr>{}, /*with_runtime_context=*/true);
1302}
1303
1304Expr ASTBuilder::insert_patch_idx_expr() {
1305 auto loop = stack_.size() ? stack_.back()->parent_stmt : nullptr;
1306 if (loop != nullptr) {
1307 auto i = stack_.size() - 1;
1308 while (!(loop->is<FrontendForStmt>())) {
1309 loop = i > 0 ? stack_[--i]->parent_stmt : nullptr;
1310 if (loop == nullptr)
1311 break;
1312 }
1313 }
1314 TI_ERROR_IF(!(loop && loop->is<FrontendForStmt>() &&
1315 loop->as<FrontendForStmt>()->mesh),
1316 "ti.mesh_patch_idx() is only valid within mesh-for loops.");
1317 return Expr::make<MeshPatchIndexExpression>();
1318}
1319
1320void ASTBuilder::create_kernel_exprgroup_return(const ExprGroup &group) {
1321 auto expanded_exprs = this->expand_exprs(group.exprs);
1322 ExprGroup expanded_expr_group;
1323 expanded_expr_group.exprs = std::move(expanded_exprs);
1324 this->insert(Stmt::make<FrontendReturnStmt>(expanded_expr_group));
1325}
1326
1327void ASTBuilder::create_print(
1328 std::vector<std::variant<Expr, std::string>> contents) {
1329 this->insert(std::make_unique<FrontendPrintStmt>(contents));
1330}
1331
1332void ASTBuilder::begin_func(const std::string &funcid) {
1333 auto stmt_unique = std::make_unique<FrontendFuncDefStmt>(funcid);
1334 auto stmt = stmt_unique.get();
1335 this->insert(std::move(stmt_unique));
1336 this->create_scope(stmt->body);
1337}
1338
1339void ASTBuilder::end_func(const std::string &funcid) {
1340 this->pop_scope();
1341}
1342
1343void ASTBuilder::begin_frontend_if(const Expr &cond) {
1344 auto stmt_tmp = std::make_unique<FrontendIfStmt>(cond);
1345 this->insert(std::move(stmt_tmp));
1346}
1347
1348void ASTBuilder::begin_frontend_if_true() {
1349 auto if_stmt = this->get_last_stmt()->as<FrontendIfStmt>();
1350 this->create_scope(if_stmt->true_statements);
1351}
1352
1353void ASTBuilder::begin_frontend_if_false() {
1354 auto if_stmt = this->get_last_stmt()->as<FrontendIfStmt>();
1355 this->create_scope(if_stmt->false_statements);
1356}
1357
1358void ASTBuilder::insert_external_func_call(std::size_t func_addr,
1359 std::string source,
1360 std::string filename,
1361 std::string funcname,
1362 const ExprGroup &args,
1363 const ExprGroup &outputs) {
1364 auto stmt = Stmt::make<FrontendExternalFuncStmt>(
1365 (void *)func_addr, source, filename, funcname, args.exprs, outputs.exprs);
1366 this->insert(std::move(stmt));
1367}
1368
1369Expr ASTBuilder::expr_alloca() {
1370 auto var = Expr(std::make_shared<IdExpression>(get_next_id()));
1371 this->insert(std::make_unique<FrontendAllocaStmt>(
1372 std::static_pointer_cast<IdExpression>(var.expr)->id,
1373 PrimitiveType::unknown));
1374 return var;
1375}
1376
1377std::optional<Expr> ASTBuilder::insert_func_call(Function *func,
1378 const ExprGroup &args) {
1379 ExprGroup expanded_args;
1380 expanded_args.exprs = this->expand_exprs(args.exprs);
1381 if (func->ret_type) {
1382 auto var = Expr(std::make_shared<IdExpression>(get_next_id()));
1383 this->insert(std::make_unique<FrontendFuncCallStmt>(
1384 func, expanded_args,
1385 std::static_pointer_cast<IdExpression>(var.expr)->id));
1386 var.expr->ret_type = func->ret_type;
1387 return var;
1388 } else {
1389 this->insert(std::make_unique<FrontendFuncCallStmt>(func, expanded_args));
1390 return std::nullopt;
1391 }
1392}
1393
1394Expr ASTBuilder::make_matrix_expr(const std::vector<int> &shape,
1395 const DataType &dt,
1396 const std::vector<Expr> &elements) {
  /*
  Since we have both "shape" and "element_type" in MatrixExpression,
  we flatten all the elements and disallow recursive TensorTypes in the
  element Exprs.
  */
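  //
  // E.g. building a 2x2 matrix from two 2-element vector Exprs first expands
  // them into four scalar Exprs via expand_exprs() below (illustrative).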
1402 TI_ASSERT(dt->is<PrimitiveType>());
1403 auto expanded_elements = this->expand_exprs(elements);
1404 auto mat =
1405 Expr(std::make_shared<MatrixExpression>(expanded_elements, shape, dt));
1406 return mat;
1407}
1408
1409Expr ASTBuilder::expr_alloca_shared_array(const std::vector<int> &shape,
1410 const DataType &element_type) {
1411 auto var = Expr(std::make_shared<IdExpression>(get_next_id()));
1412 this->insert(std::make_unique<FrontendAllocaStmt>(
1413 std::static_pointer_cast<IdExpression>(var.expr)->id, shape, element_type,
1414 true));
1415 var->ret_type = this->get_last_stmt()->ret_type;
1416 return var;
1417}
1418
1419void ASTBuilder::expr_assign(const Expr &lhs, const Expr &rhs, std::string tb) {
1420 TI_ASSERT(lhs->is_lvalue());
1421 auto stmt = std::make_unique<FrontendAssignStmt>(lhs, rhs);
1422 stmt->set_tb(tb);
1423 this->insert(std::move(stmt));
1424}
1425
1426Expr ASTBuilder::expr_subscript(const Expr &expr,
1427 const ExprGroup &indices,
1428 std::string tb) {
1429 TI_ASSERT(expr.is<FieldExpression>() || expr.is<MatrixFieldExpression>() ||
1430 expr.is<ExternalTensorExpression>() ||
1431 is_tensor(expr.expr->ret_type));
1432
1433 // IndexExpression without ret_shape is used for matrix indexing,
1434 // where each entry of ExprGroup is interpreted as indexing into a specific
1435 // axis. For example, mat[3, 4] has indices_group={[3, 4]}, where [3, 4]
1436 // corresponds to "n"-axis and "m"-axis of the matrix. Therefore we expand
1437 // indices_group={[3, 4]} into {3, 4} to avoid TensorType in indices.
1438 std::vector<Expr> expanded_indices = this->expand_exprs(indices.exprs);
1439 auto expanded_expr_group = ExprGroup();
1440 expanded_expr_group.exprs = expanded_indices;
1441
1442 return Expr::make<IndexExpression>(expr, expanded_expr_group, tb);
1443}
1444
1445void ASTBuilder::create_assert_stmt(const Expr &cond,
1446 const std::string &msg,
1447 const std::vector<Expr> &args) {
1448 auto stmt_unique = std::make_unique<FrontendAssertStmt>(cond, msg, args);
1449 this->insert(std::move(stmt_unique));
1450}
1451
1452void ASTBuilder::begin_frontend_range_for(const Expr &i,
1453 const Expr &s,
1454 const Expr &e) {
1455 auto stmt_unique =
1456 std::make_unique<FrontendForStmt>(i, s, e, arch_, for_loop_dec_.config);
1457 auto stmt = stmt_unique.get();
1458 this->insert(std::move(stmt_unique));
1459 this->create_scope(stmt->body,
1460 for_loop_dec_.config.strictly_serialized ? While : For);
1461 for_loop_dec_.reset();
1462}
1463
1464void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars,
1465 SNode *snode) {
1466 TI_WARN_IF(
1467 for_loop_dec_.config.strictly_serialized,
      "ti.loop_config(serialize=True) has no effect on the struct for. "
1469 "The execution order is not guaranteed.");
1470 auto stmt_unique = std::make_unique<FrontendForStmt>(loop_vars, snode, arch_,
1471 for_loop_dec_.config);
1472 for_loop_dec_.reset();
1473 auto stmt = stmt_unique.get();
1474 this->insert(std::move(stmt_unique));
1475 this->create_scope(stmt->body, For);
1476}
1477
1478void ASTBuilder::begin_frontend_struct_for_on_external_tensor(
1479 const ExprGroup &loop_vars,
1480 const Expr &external_tensor) {
1481 TI_WARN_IF(
1482 for_loop_dec_.config.strictly_serialized,
      "ti.loop_config(serialize=True) has no effect on the struct for. "
1484 "The execution order is not guaranteed.");
1485 auto stmt_unique = std::make_unique<FrontendForStmt>(
1486 loop_vars, external_tensor, arch_, for_loop_dec_.config);
1487 for_loop_dec_.reset();
1488 auto stmt = stmt_unique.get();
1489 this->insert(std::move(stmt_unique));
1490 this->create_scope(stmt->body, For);
1491}
1492
1493void ASTBuilder::begin_frontend_mesh_for(
1494 const Expr &i,
1495 const mesh::MeshPtr &mesh_ptr,
1496 const mesh::MeshElementType &element_type) {
1497 TI_WARN_IF(
1498 for_loop_dec_.config.strictly_serialized,
      "ti.loop_config(serialize=True) has no effect on the mesh for. "
1500 "The execution order is not guaranteed.");
1501 auto stmt_unique = std::make_unique<FrontendForStmt>(
1502 ExprGroup(i), mesh_ptr, element_type, arch_, for_loop_dec_.config);
1503 for_loop_dec_.reset();
1504 auto stmt = stmt_unique.get();
1505 this->insert(std::move(stmt_unique));
1506 this->create_scope(stmt->body, For);
1507}
1508
1509void ASTBuilder::begin_frontend_while(const Expr &cond) {
1510 auto stmt_unique = std::make_unique<FrontendWhileStmt>(cond);
1511 auto stmt = stmt_unique.get();
1512 this->insert(std::move(stmt_unique));
1513 this->create_scope(stmt->body, While);
1514}
1515
1516void ASTBuilder::insert_break_stmt() {
1517 if (loop_state_stack_.back() == Outermost) {
1518 throw TaichiSyntaxError("Cannot break in the outermost loop");
1519 }
1520 this->insert(Stmt::make<FrontendBreakStmt>());
1521}
1522
1523void ASTBuilder::insert_continue_stmt() {
1524 this->insert(Stmt::make<FrontendContinueStmt>());
1525}
1526
1527void ASTBuilder::insert_expr_stmt(const Expr &val) {
1528 this->insert(Stmt::make<FrontendExprStmt>(val));
1529}
1530
1531void ASTBuilder::insert_snode_activate(SNode *snode,
1532 const ExprGroup &expr_group) {
1533 ExprGroup expanded_group;
1534 expanded_group.exprs = this->expand_exprs(expr_group.exprs);
1535 this->insert(Stmt::make<FrontendSNodeOpStmt>(SNodeOpType::activate, snode,
1536 expanded_group));
1537}
1538
1539void ASTBuilder::insert_snode_deactivate(SNode *snode,
1540 const ExprGroup &expr_group) {
1541 ExprGroup expanded_group;
1542 expanded_group.exprs = this->expand_exprs(expr_group.exprs);
1543 this->insert(Stmt::make<FrontendSNodeOpStmt>(SNodeOpType::deactivate, snode,
1544 expanded_group));
1545}
1546
1547Expr ASTBuilder::snode_append(SNode *snode,
1548 const ExprGroup &indices,
1549 const std::vector<Expr> &vals) {
1550 ExprGroup expanded_exprs;
1551 expanded_exprs.exprs = this->expand_exprs(indices.exprs);
1552 std::vector<Expr> expanded_vals = this->expand_exprs(vals);
1553 return Expr::make<SNodeOpExpression>(snode, SNodeOpType::append,
1554 expanded_exprs, expanded_vals);
1555}
1556
1557Expr ASTBuilder::snode_is_active(SNode *snode, const ExprGroup &indices) {
1558 ExprGroup expanded_exprs;
1559 expanded_exprs.exprs = this->expand_exprs(indices.exprs);
1560 return Expr::make<SNodeOpExpression>(snode, SNodeOpType::is_active,
1561 expanded_exprs);
1562}
1563
1564Expr ASTBuilder::snode_length(SNode *snode, const ExprGroup &indices) {
1565 ExprGroup expanded_exprs;
1566 expanded_exprs.exprs = this->expand_exprs(indices.exprs);
1567 return Expr::make<SNodeOpExpression>(snode, SNodeOpType::length,
1568 expanded_exprs);
1569}
1570
1571Expr ASTBuilder::snode_get_addr(SNode *snode, const ExprGroup &indices) {
1572 ExprGroup expanded_exprs;
1573 expanded_exprs.exprs = this->expand_exprs(indices.exprs);
1574 return Expr::make<SNodeOpExpression>(snode, SNodeOpType::get_addr,
1575 expanded_exprs);
1576}
1577
1578std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
1579 if (exprs.size() == 0) {
1580 return exprs;
1581 }
1582
1583 std::vector<Expr> expanded_exprs;
1584 for (auto expr : exprs) {
1585 TI_ASSERT_TYPE_CHECKED(expr);
1586 if (!expr->ret_type->is<TensorType>()) {
1587 expanded_exprs.push_back(expr);
1588 } else {
1589 // Expand TensorType expr
1590 /*
1591 Before:
1592 TensorType<4 x i32> index = Expr;
1593
1594 After:
1595 TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>)
1596 i32 ind0 = IndexExpression(id_expr, 0)
1597 i32 ind1 = IndexExpression(id_expr, 1)
1598 i32 ind2 = IndexExpression(id_expr, 2)
1599 i32 ind3 = IndexExpression(id_expr, 3)
1600
1601 return {ind0, ind1, ind2, ind3}
1602
1603 */
1604 auto tensor_type = expr->ret_type->cast<TensorType>();
1605
1606 Expr id_expr;
1607 if (expr.is<IdExpression>()) {
1608 id_expr = expr;
1609 } else {
1610 id_expr = make_var(expr, expr->tb);
1611 }
1612 auto shape = tensor_type->get_shape();
1613 if (shape.size() == 1) {
1614 for (int i = 0; i < shape[0]; i++) {
1615 auto ind = Expr(std::make_shared<IndexExpression>(
1616 id_expr, ExprGroup(Expr(i)), expr->tb));
1617 ind.expr->ret_type = tensor_type->get_element_type();
1618 expanded_exprs.push_back(ind);
1619 }
1620 } else {
1621 TI_ASSERT(shape.size() == 2);
1622 for (int i = 0; i < shape[0]; i++) {
1623 for (int j = 0; j < shape[1]; j++) {
1624 auto ind = Expr(std::make_shared<IndexExpression>(
1625 id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb));
1626 ind.expr->ret_type = tensor_type->get_element_type();
1627 expanded_exprs.push_back(ind);
1628 }
1629 }
1630 }
1631 }
1632 }
1633
1634 return expanded_exprs;
1635}
1636
1637Expr ASTBuilder::mesh_index_conversion(mesh::MeshPtr mesh_ptr,
1638 mesh::MeshElementType idx_type,
1639 const Expr &idx,
1640 mesh::ConvType &conv_type) {
1641 Expr expanded_idx;
1642 if (idx.is<IdExpression>() && idx.get_ret_type() == PrimitiveType::unknown) {
1643 expanded_idx = idx;
1644 } else {
1645 if (idx.expr->ret_type->is<TensorType>()) {
1646 TI_ASSERT(idx.expr->ret_type->cast<TensorType>()->get_num_elements() ==
1647 1);
1648 }
1649 expanded_idx = this->expand_exprs({idx})[0];
1650 }
1651
1652 return Expr::make<MeshIndexConversionExpression>(mesh_ptr.ptr.get(), idx_type,
1653 expanded_idx, conv_type);
1654}
1655
1656void ASTBuilder::create_scope(std::unique_ptr<Block> &list, LoopType tp) {
1657 TI_ASSERT(list == nullptr);
1658 LoopState prev = loop_state_stack_.back();
1659 if (tp == NotLoop) {
1660 loop_state_stack_.push_back(prev);
1661 } else if (tp == For && stack_.size() == 1) {
1662 loop_state_stack_.push_back(Outermost);
1663 } else {
1664 loop_state_stack_.push_back(Inner);
1665 }
1666 list = std::make_unique<Block>();
1667 if (!stack_.empty()) {
1668 list->parent_stmt = get_last_stmt();
1669 }
1670 stack_.push_back(list.get());
1671}
1672
1673void ASTBuilder::pop_scope() {
1674 stack_.pop_back();
1675 loop_state_stack_.pop_back();
1676}
1677
1678Expr ASTBuilder::make_texture_op_expr(const TextureOpType &op,
1679 const Expr &texture_ptr,
1680 const ExprGroup &args) {
1681 ExprGroup expanded_args;
1682 expanded_args.exprs = this->expand_exprs(args.exprs);
1683 return Expr::make<TextureOpExpression>(op, texture_ptr, expanded_args);
1684}
1685
1686Stmt *flatten_lvalue(Expr expr, Expression::FlattenContext *ctx) {
1687 expr->flatten(ctx);
1688 return expr->get_flattened_stmt();
1689}
1690
1691Stmt *flatten_global_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) {
1692 ctx->push_back(std::make_unique<GlobalLoadStmt>(ptr_stmt));
1693 return ctx->back_stmt();
1694}
1695
1696Stmt *flatten_local_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) {
1697 auto local_load = ctx->push_back<LocalLoadStmt>(ptr_stmt);
1698 local_load->ret_type = local_load->src->ret_type.ptr_removed();
1699 return local_load;
1700}
1701
1702Stmt *flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) {
1703 ptr->flatten(ctx);
1704 Stmt *ptr_stmt = ptr->get_flattened_stmt();
1705 if (ptr.is<IdExpression>()) {
1706 if (ptr_stmt->is<AllocaStmt>()) {
1707 return flatten_local_load(ptr_stmt, ctx);
1708 }
1709 } else if (ptr.is<IndexExpression>()) {
1710 auto ix = ptr.cast<IndexExpression>();
1711 if (ix->is_local()) {
1712 return flatten_local_load(ptr_stmt, ctx);
1713 } else {
1714 return flatten_global_load(ptr_stmt, ctx);
1715 }
1716 } else if (ptr.is<ArgLoadExpression>() &&
1717 ptr.cast<ArgLoadExpression>()->is_ptr) {
1718 return flatten_global_load(ptr_stmt, ctx);
1719 }
1720
1721 return ptr_stmt;
1722}
1723
1724} // namespace taichi::lang
1725