// Type checking

#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/transforms/utils.h"

namespace taichi::lang {

static_assert(
    sizeof(real) == sizeof(float32),
    "Please build the taichi compiler with single precision (TI_USE_DOUBLE=0)");

// Var lookup and Type inference
class TypeCheck : public IRVisitor {
 private:
  CompileConfig config_;

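  // Checks a store-like statement (local store, global store, atomic op): if
  // the value type differs from the destination's pointee type, a cast is
  // inserted before the store. For example, storing an f64 value into an f32
  // destination triggers the "may lose precision" warning below and inserts
  // cast_value(f32) on the value, so the store itself always sees matching
  // types.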
  Type *type_check_store(Stmt *stmt,
                         Stmt *dst,
                         Stmt *&val,
                         const std::string &stmt_name) {
    auto dst_type = dst->ret_type.ptr_removed();
    if (is_quant(dst_type)) {
      // We force the value type to be the compute_type of the bit pointer.
      // Casting from compute_type to physical_type is handled in codegen.
      dst_type = dst_type->get_compute_type();
    }
    if (dst_type != val->ret_type) {
      auto promoted = promoted_type(dst_type, val->ret_type);
      if (dst_type != promoted) {
        TI_WARN("[{}] {} may lose precision: {} <- {}\n{}", stmt->name(),
                stmt_name, dst_type->to_string(), val->ret_data_type_name(),
                stmt->tb);
      }
      val = insert_type_cast_before(stmt, val, dst_type);
    }
    return dst_type;
  }

 public:
  explicit TypeCheck(const CompileConfig &config) : config_(config) {
    allow_undefined_visitor = true;
  }

  static void mark_as_if_const(Stmt *stmt, DataType t) {
    if (stmt->is<ConstStmt>()) {
      stmt->ret_type = t;
    }
  }

  void visit(AllocaStmt *stmt) override {
    // Do nothing. The alloca's type is determined by the first LocalStore
    // that targets it, in IR visiting order, at compile time.

    // ret_type stands for its element type.
  }

  void visit(IfStmt *if_stmt) override {
    if (if_stmt->true_statements)
      if_stmt->true_statements->accept(this);
    if (if_stmt->false_statements) {
      if_stmt->false_statements->accept(this);
    }
  }

  void visit(Block *stmt_list) override {
    std::vector<Stmt *> stmts;
    // Make a copy since type casts may be inserted for type promotion.
    for (auto &stmt : stmt_list->statements) {
      stmts.push_back(stmt.get());
    }
    for (auto stmt : stmts)
      stmt->accept(this);
  }

  void visit(AtomicOpStmt *stmt) override {
    // TODO(type): test_ad_for fails if we assume dest is a pointer type.
    stmt->ret_type = type_check_store(
        stmt, stmt->dest, stmt->val,
        fmt::format("Atomic {}", atomic_op_type_name(stmt->op_type)));
  }

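  // A local load through a MatrixPtrStmt reads one element of a tensor-typed
  // alloca or global temporary, so the result type is the tensor's element
  // type: e.g., loading one lane of a TensorType(3, f32) alloca yields f32.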
  void visit(LocalLoadStmt *stmt) override {
    TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>());
    if (auto ptr_offset_stmt = stmt->src->cast<MatrixPtrStmt>()) {
      TI_ASSERT(ptr_offset_stmt->origin->is<AllocaStmt>() ||
                ptr_offset_stmt->origin->is<GlobalTemporaryStmt>());
      if (auto alloca_stmt = ptr_offset_stmt->origin->cast<AllocaStmt>()) {
        auto lookup =
            DataType(
                alloca_stmt->ret_type->as<TensorType>()->get_element_type())
                .ptr_removed();
        stmt->ret_type = lookup;
      }
      if (auto global_temporary_stmt =
              ptr_offset_stmt->origin->cast<GlobalTemporaryStmt>()) {
        auto lookup = DataType(global_temporary_stmt->ret_type->as<TensorType>()
                                   ->get_element_type())
                          .ptr_removed();
        stmt->ret_type = lookup;
      }
    } else {
      auto lookup = stmt->src->ret_type;
      stmt->ret_type = lookup;
    }
  }

  void visit(LocalStoreStmt *stmt) override {
    if (stmt->dest->ret_type->is_primitive(PrimitiveTypeID::unknown)) {
      // Infer data type for alloca
      stmt->dest->ret_type = stmt->val->ret_type;
    }
    stmt->ret_type =
        type_check_store(stmt, stmt->dest, stmt->val, "Local store");
  }

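  // A global load yields the compute type of the pointee: for an ordinary
  // f32 field this is f32 itself, while for quantized types it is the type
  // arithmetic is performed in (typically i32 for quant ints).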
  void visit(GlobalLoadStmt *stmt) override {
    auto pointee_type = stmt->src->ret_type.ptr_removed();
    stmt->ret_type = pointee_type->get_compute_type();
  }

  void visit(SNodeOpStmt *stmt) override {
    if (stmt->op_type == SNodeOpType::get_addr) {
      stmt->ret_type = PrimitiveType::u64;
    } else if (stmt->op_type == SNodeOpType::allocate) {
      stmt->ret_type = PrimitiveType::gen;
      stmt->ret_type.set_is_pointer(true);
    } else {
      stmt->ret_type = PrimitiveType::i32;
    }
  }

  void visit(ExternalTensorShapeAlongAxisStmt *stmt) override {
    stmt->ret_type = PrimitiveType::i32;
  }

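  // Resolves a field access to a typed pointer: for an f32 field x, the
  // pointer gets type *f32. The index count must match the SNode's active
  // indices, and any non-i32 index is cast to i32 with a warning.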
  void visit(GlobalPtrStmt *stmt) override {
    if (stmt->is_bit_vectorized) {
      return;
    }
    stmt->ret_type.set_is_pointer(true);
    if (stmt->snode) {
      stmt->ret_type =
          TypeFactory::get_instance().get_pointer_type(stmt->snode->dt);
    } else {
      TI_WARN("[{}] Type inference failed: snode is nullptr.\n{}", stmt->name(),
              stmt->tb);
    }
    auto check_indices = [&](SNode *snode) {
      if (snode->num_active_indices != stmt->indices.size()) {
        TI_ERROR("[{}] {} has {} indices. Indexed with {}.", stmt->name(),
                 snode->node_type_name, snode->num_active_indices,
                 stmt->indices.size());
      }
    };
    check_indices(stmt->is_cell_access ? stmt->snode : stmt->snode->parent);
    for (int i = 0; i < stmt->indices.size(); i++) {
      if (!stmt->indices[i]->ret_type->is_primitive(PrimitiveTypeID::i32)) {
        TI_WARN(
            "[{}] Field index {} not int32, casting into int32 "
            "implicitly\n{}",
            stmt->name(), i, stmt->tb);
        stmt->indices[i] =
            insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32);
      }
    }
  }

  void visit(MatrixPtrStmt *stmt) override {
    TI_ASSERT(stmt->offset->ret_type->is_primitive(PrimitiveTypeID::i32));
    stmt->ret_type.set_is_pointer(true);
  }

  void visit(GlobalStoreStmt *stmt) override {
    type_check_store(stmt, stmt->dest, stmt->val, "Global store");
  }

  void visit(RangeForStmt *stmt) override {
    mark_as_if_const(stmt->begin, PrimitiveType::i32);
    mark_as_if_const(stmt->end, PrimitiveType::i32);
    stmt->body->accept(this);
  }

  void visit(StructForStmt *stmt) override {
    stmt->body->accept(this);
  }

  void visit(MeshForStmt *stmt) override {
    stmt->body->accept(this);
  }

  void visit(WhileStmt *stmt) override {
    stmt->body->accept(this);
  }

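  // Casts keep the operand's tensor shape: casting a TensorType(3, i32) to
  // f32 produces TensorType(3, f32). Real-valued ops on integral operands are
  // promoted: e.g., sqrt of an i32 operand first casts the operand to
  // default_fp (say f32), and the result type becomes f32 as well.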
  void visit(UnaryOpStmt *stmt) override {
    auto operand_type = stmt->operand->ret_type;
    stmt->ret_type = operand_type;
    if (stmt->is_cast()) {
      stmt->ret_type = stmt->cast_type;
      if (operand_type->is<TensorType>() &&
          stmt->cast_type->is<PrimitiveType>()) {
        auto ret_tensor_type = operand_type->as<TensorType>();
        auto tensor_shape = ret_tensor_type->get_shape();
        stmt->ret_type = TypeFactory::get_instance().create_tensor_type(
            tensor_shape, stmt->cast_type);
      }
    }

    DataType primitive_dtype = stmt->operand->ret_type.get_element_type();
    if (!is_real(primitive_dtype)) {
      if (stmt->op_type == UnaryOpType::sqrt ||
          stmt->op_type == UnaryOpType::exp ||
          stmt->op_type == UnaryOpType::log) {
        DataType target_dtype = config_.default_fp;
        if (stmt->operand->ret_type->is<TensorType>()) {
          target_dtype = TypeFactory::get_instance().create_tensor_type(
              stmt->operand->ret_type->as<TensorType>()->get_shape(),
              target_dtype);
        }

        cast(stmt->operand, target_dtype);
        stmt->ret_type = target_dtype;
      }
    }
  }

  Stmt *insert_type_cast_before(Stmt *anchor,
                                Stmt *input,
                                DataType output_type) {
    auto &&cast_stmt =
        Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, input);
    cast_stmt->cast_type = output_type;
    cast_stmt->accept(this);
    auto stmt = cast_stmt.get();
    anchor->insert_before_me(std::move(cast_stmt));
    return stmt;
  }

  Stmt *insert_type_cast_after(Stmt *anchor,
                               Stmt *input,
                               DataType output_type) {
    auto &&cast_stmt =
        Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, input);
    cast_stmt->cast_type = output_type;
    cast_stmt->accept(this);
    auto stmt = cast_stmt.get();
    anchor->insert_after_me(std::move(cast_stmt));
    return stmt;
  }

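  // In debug mode, guards a shift against an oversized rhs by emitting, right
  // before the shift, the equivalent of:
  //   cond = (rhs <= data_type_bits(lhs type));  // e.g. 32 for an i32 lhs
  //   assert(cond, "Detected overflow for bit_shift_op ...", rhs, limit);
  // All three inserted statements are type-checked immediately.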
  void insert_shift_op_assertion_before(Stmt *stmt, Stmt *lhs, Stmt *rhs) {
    int rhs_limit = data_type_bits(lhs->ret_type);
    auto const_stmt =
        Stmt::make<ConstStmt>(TypedConstant(rhs->ret_type, rhs_limit));
    auto cond_stmt =
        Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_le, rhs, const_stmt.get());

    std::string msg =
        "Detected overflow for bit_shift_op with rhs = %d, exceeding limit of "
        "%d.";
    msg += "\n" + stmt->tb;
    std::vector<Stmt *> args = {rhs, const_stmt.get()};
    auto assert_stmt =
        Stmt::make<AssertStmt>(cond_stmt.get(), msg, std::move(args));

    const_stmt->accept(this);
    cond_stmt->accept(this);
    assert_stmt->accept(this);

    stmt->insert_before_me(std::move(const_stmt));
    stmt->insert_before_me(std::move(cond_stmt));
    stmt->insert_before_me(std::move(assert_stmt));
  }

  void cast(Stmt *&val, DataType dt) {
    if (val->ret_type == dt)
      return;

    auto cast_stmt = insert_type_cast_after(val, val, dt);
    val = cast_stmt;
  }

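  // Binary ops follow the usual promotion rules: mismatched operand types are
  // promoted to a common type (e.g., u8 + i32 casts the u8 side to i32),
  // except for shift ops, whose result keeps the lhs type. Comparisons always
  // produce i32 (element-wise for tensor operands).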
  void visit(BinaryOpStmt *stmt) override {
    auto error = [&](std::string comment = "") {
      if (comment == "") {
        TI_WARN("[{}] Type mismatch (left = {}, right = {}, stmt_id = {})\n{}",
                stmt->name(), stmt->lhs->ret_data_type_name(),
                stmt->rhs->ret_data_type_name(), stmt->id, stmt->tb);
      } else {
        TI_WARN("[{}] {}\n{}", stmt->name(), comment, stmt->tb);
      }
      TI_WARN("Compilation stopped due to type mismatch.");
      throw std::runtime_error("Binary operator type mismatch");
    };
    if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::unknown) &&
        stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::unknown))
      error();
    if (stmt->op_type == BinaryOpType::pow &&
        is_integral(stmt->rhs->ret_type.get_element_type())) {
      stmt->ret_type = stmt->lhs->ret_type;
      return;
    }

    // Broadcast a primitive dtype to the lhs tensor shape, if any.
    auto make_dt = [stmt](DataType dt) {
      if (auto tensor_ty = stmt->lhs->ret_type->cast<TensorType>()) {
        return TypeFactory::create_tensor_type(tensor_ty->get_shape(), dt);
      } else {
        return dt;
      }
    };

    // Lower truediv into div.

    if (stmt->op_type == BinaryOpType::truediv) {
      auto default_fp = config_.default_fp;
      if (!is_real(stmt->lhs->ret_type.get_element_type())) {
        cast(stmt->lhs, make_dt(default_fp));
      }
      if (!is_real(stmt->rhs->ret_type.get_element_type())) {
        cast(stmt->rhs, make_dt(default_fp));
      }
      stmt->op_type = BinaryOpType::div;
    }
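    // For example, i32 / i32 under default_fp == f32 becomes f32 / f32 with a
    // div op; tensor operands keep their shape via make_dt.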

    // Some backends such as Vulkan don't support fp64.
    // Always promote to fp32 unless fp64 is actually required.
    if (stmt->op_type == BinaryOpType::atan2) {
      if (stmt->rhs->ret_type == PrimitiveType::f64 ||
          stmt->lhs->ret_type == PrimitiveType::f64) {
        stmt->ret_type = make_dt(PrimitiveType::f64);
        cast(stmt->rhs, make_dt(PrimitiveType::f64));
        cast(stmt->lhs, make_dt(PrimitiveType::f64));
      } else {
        stmt->ret_type = make_dt(PrimitiveType::f32);
        cast(stmt->rhs, make_dt(PrimitiveType::f32));
        cast(stmt->lhs, make_dt(PrimitiveType::f32));
      }
    }

    if (stmt->lhs->ret_type != stmt->rhs->ret_type) {
      DataType ret_type;
      if (is_shift_op(stmt->op_type)) {
        // Shift ops do not follow the same type promotion rule as numerical
        // ops. Numerical ops: u8 + i32 = i32. Shift ops: u8 << i32 = u8
        // (the return dtype follows that of the lhs).
        //
        // In the example above, truncating the rhs (i32) to u8 risks an
        // overflow, but the runtime value of the rhs is very likely less than
        // 8 (otherwise the shift is meaningless). Nevertheless, we insert an
        // AssertStmt here to warn the user of the potential overflow.
        ret_type = stmt->lhs->ret_type;

        // Insert AssertStmt
        if (config_.debug) {
          insert_shift_op_assertion_before(stmt, stmt->lhs, stmt->rhs);
        }
      } else {
        ret_type = promoted_type(stmt->lhs->ret_type, stmt->rhs->ret_type);
      }

      if (ret_type != stmt->lhs->ret_type) {
        // promote lhs
        auto cast_stmt = insert_type_cast_before(stmt, stmt->lhs, ret_type);
        stmt->lhs = cast_stmt;
      }
      if (ret_type != stmt->rhs->ret_type) {
        // promote rhs
        auto cast_stmt = insert_type_cast_before(stmt, stmt->rhs, ret_type);
        stmt->rhs = cast_stmt;
      }
    }
    bool matching = true;
    matching = matching && (stmt->lhs->ret_type != PrimitiveType::unknown);
    matching = matching && (stmt->rhs->ret_type != PrimitiveType::unknown);
    matching = matching && (stmt->lhs->ret_type == stmt->rhs->ret_type);
    if (!matching) {
      error();
    }
    if (is_comparison(stmt->op_type)) {
      stmt->ret_type = make_dt(PrimitiveType::i32);
    } else {
      stmt->ret_type = stmt->lhs->ret_type;
    }
  }

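  // select(cond, a, b) promotes a and b to a common type and requires an i32
  // condition: e.g., select(c, 1, 2.0f) casts op2 to f32, and the result type
  // is f32.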
  void visit(TernaryOpStmt *stmt) override {
    if (stmt->op_type == TernaryOpType::select) {
      auto ret_type = promoted_type(stmt->op2->ret_type, stmt->op3->ret_type);
      TI_ASSERT(stmt->op1->ret_type->is_primitive(PrimitiveTypeID::i32));
      if (ret_type != stmt->op2->ret_type) {
        auto cast_stmt = insert_type_cast_before(stmt, stmt->op2, ret_type);
        stmt->op2 = cast_stmt;
      }
      if (ret_type != stmt->op3->ret_type) {
        auto cast_stmt = insert_type_cast_before(stmt, stmt->op3, ret_type);
        stmt->op3 = cast_stmt;
      }
      stmt->ret_type = ret_type;
    } else {
      TI_NOT_IMPLEMENTED
    }
  }

  void visit(RangeAssumptionStmt *stmt) override {
    stmt->ret_type = stmt->input->ret_type;
  }

  void visit(LoopUniqueStmt *stmt) override {
    stmt->ret_type = stmt->input->ret_type;
  }

  void visit(FuncCallStmt *stmt) override {
    auto *func = stmt->func;
    TI_ASSERT(func);
    stmt->ret_type = func->ret_type;
  }

  void visit(FrontendFuncCallStmt *stmt) override {
    auto *func = stmt->func;
    TI_ASSERT(func);
    stmt->ret_type = func->ret_type;
  }

  void visit(GetElementStmt *stmt) override {
    stmt->ret_type =
        stmt->src->ret_type->as<StructType>()->get_element_type(stmt->index);
  }

  void visit(ArgLoadStmt *stmt) override {
    // TODO: Maybe have a type_inference() pass, which takes in the args/rets
    // defined by the kernel. After that, the type_check() pass will purely do
    // verification, without modifying any types.
    stmt->ret_type.set_is_pointer(stmt->is_ptr);
  }

  void visit(ReturnStmt *stmt) override {
    // TODO: Support stmt->ret_id?
  }

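  // Example: for a float ndarray argument with two external dims, x[i, j]
  // (two indices) keeps the arg-load's type (case #1 below), while appending
  // inner tensor indices yields the primitive element type (case #2).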
  void visit(ExternalPtrStmt *stmt) override {
    /* ExternalPtrStmt may have two different semantics:
       1. outer indexing into an arg-loaded external tensor
       2. outer indexing + inner indexing to get the innermost primitive
          element of an external tensor
       We rely on "external_dims" and "indices" to distinguish the two cases:
       Case #1: external_dims == indices.size(), returns TensorType
       Case #2: external_dims < indices.size(), returns PrimitiveType
    */
    TI_ASSERT(stmt->base_ptr->is<ArgLoadStmt>());
    auto arg_load_stmt = stmt->base_ptr->cast<ArgLoadStmt>();

    int external_dims = arg_load_stmt->field_dims_;
    if (external_dims == stmt->indices.size() || external_dims == -1) {
      stmt->ret_type = arg_load_stmt->ret_type;
    } else {
      stmt->ret_type = arg_load_stmt->ret_type.ptr_removed().get_element_type();
    }

    stmt->ret_type.set_is_pointer(true);
    for (int i = 0; i < stmt->indices.size(); i++) {
      TI_ASSERT(is_integral(stmt->indices[i]->ret_type));
      if (stmt->indices[i]->ret_type != PrimitiveType::i32) {
        stmt->indices[i] =
            insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32);
      }
    }
  }

  void visit(LoopIndexStmt *stmt) override {
    stmt->ret_type = PrimitiveType::i32;
  }

  void visit(LoopLinearIndexStmt *stmt) override {
    stmt->ret_type = PrimitiveType::i32;
  }

  void visit(BlockCornerIndexStmt *stmt) override {
    stmt->ret_type = PrimitiveType::i32;
  }

  void visit(GetRootStmt *stmt) override {
    stmt->ret_type =
        TypeFactory::get_instance().get_pointer_type(PrimitiveType::gen);
  }

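  // Lookups into a quant_array produce bit pointers (the `true` argument to
  // get_pointer_type), since its elements are not byte-addressable; all other
  // SNode lookups yield an ordinary pointer to a generic cell.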
  void visit(SNodeLookupStmt *stmt) override {
    if (stmt->snode->type == SNodeType::quant_array) {
      auto quant_array_type = stmt->snode->dt;
      auto element_type =
          quant_array_type->cast<QuantArrayType>()->get_element_type();
      auto pointer_type =
          TypeFactory::get_instance().get_pointer_type(element_type, true);
      stmt->ret_type = pointer_type;
    } else {
      stmt->ret_type =
          TypeFactory::get_instance().get_pointer_type(PrimitiveType::gen);
    }
  }

  void visit(GetChStmt *stmt) override {
    if (stmt->is_bit_vectorized) {
      auto physical_type = stmt->output_snode->physical_type;
      auto ptr_ret_type =
          TypeFactory::get_instance().get_pointer_type(physical_type);
      stmt->ret_type = DataType(ptr_ret_type);
      return;
    }
    auto element_type = stmt->output_snode->dt;
    // For bit_struct SNodes, their component SNodes must have
    // is_bit_level=true
    auto pointer_type = TypeFactory::get_instance().get_pointer_type(
        element_type, stmt->output_snode->is_bit_level);
    stmt->ret_type = pointer_type;
  }

  void visit(OffloadedStmt *stmt) override {
    stmt->all_blocks_accept(this);
  }

  void visit(LinearizeStmt *stmt) override {
    stmt->ret_type = PrimitiveType::i32;
  }

  void visit(IntegerOffsetStmt *stmt) override {
    stmt->ret_type = PrimitiveType::i32;
  }

  void visit(AdStackAllocaStmt *stmt) override {
    stmt->ret_type = stmt->dt;
    // ret_type stands for its element type.
    stmt->ret_type.set_is_pointer(false);
  }

  void visit(AdStackLoadTopStmt *stmt) override {
    stmt->ret_type = stmt->stack->ret_type;
    stmt->ret_type.set_is_pointer(false);
  }

  void visit(AdStackLoadTopAdjStmt *stmt) override {
    stmt->ret_type = stmt->stack->ret_type;
    stmt->ret_type.set_is_pointer(false);
  }

  void visit(AdStackPushStmt *stmt) override {
    stmt->ret_type = stmt->stack->ret_type;
    stmt->ret_type.set_is_pointer(false);
    TI_ASSERT(stmt->ret_type == stmt->v->ret_type);
  }

  void visit(AdStackPopStmt *stmt) override {
    stmt->ret_type = stmt->stack->ret_type;
    stmt->ret_type.set_is_pointer(false);
  }

  void visit(AdStackAccAdjointStmt *stmt) override {
    stmt->ret_type = stmt->stack->ret_type;
    stmt->ret_type.set_is_pointer(false);
    TI_ASSERT(stmt->ret_type == stmt->v->ret_type);
  }

  void visit(GlobalTemporaryStmt *stmt) override {
    stmt->ret_type.set_is_pointer(true);
  }

  void visit(InternalFuncStmt *stmt) override {
    // TODO: support return type specification
    stmt->ret_type = PrimitiveType::i32;
  }

  void visit(BitStructStoreStmt *stmt) override {
    // do nothing
  }

  void visit(ReferenceStmt *stmt) override {
    stmt->ret_type = stmt->var->ret_type;
    stmt->ret_type.set_is_pointer(true);
  }

  void visit(MatrixInitStmt *stmt) override {
    TI_ASSERT_INFO(stmt->ret_type->is<TensorType>(),
                   "Matrix should have tensor type, got {}",
                   stmt->ret_type->to_string());
    auto tensor_type = stmt->ret_type->as<TensorType>();
    auto element_dtype = tensor_type->get_element_type();
    for (int i = 0; i < stmt->values.size(); ++i) {
      if (element_dtype != stmt->values[i]->ret_type) {
        cast(stmt->values[i], element_dtype);
      }
    }
  }
};

namespace irpass {

void type_check(IRNode *root, const CompileConfig &config) {
  TI_AUTO_PROF;
  analysis::check_fields_registered(root);
  TypeCheck inst(config);
  root->accept(&inst);
}
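
// A minimal usage sketch (names here are illustrative, not part of this
// file): after the frontend lowers a kernel to IR, run
//   irpass::type_check(kernel_ir, compile_config);
// so that every Stmt::ret_type is concrete before later passes and codegen.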

}  // namespace irpass

}  // namespace taichi::lang