1 | // Type checking |
2 | |
3 | #include "taichi/ir/ir.h" |
4 | #include "taichi/ir/statements.h" |
5 | #include "taichi/ir/transforms.h" |
6 | #include "taichi/ir/analysis.h" |
7 | #include "taichi/ir/frontend_ir.h" |
8 | #include "taichi/transforms/utils.h" |
9 | |
10 | namespace taichi::lang { |
11 | |
12 | static_assert( |
13 | sizeof(real) == sizeof(float32), |
14 | "Please build the taichi compiler with single precision (TI_USE_DOUBLE=0)" ); |
15 | |
16 | // Var lookup and Type inference |
17 | class TypeCheck : public IRVisitor { |
18 | private: |
19 | CompileConfig config_; |
20 | |
21 | Type *type_check_store(Stmt *stmt, |
22 | Stmt *dst, |
23 | Stmt *&val, |
24 | const std::string &stmt_name) { |
25 | auto dst_type = dst->ret_type.ptr_removed(); |
26 | if (is_quant(dst_type)) { |
27 | // We force the value type to be the compute_type of the bit pointer. |
28 | // Casting from compute_type to physical_type is handled in codegen. |
29 | dst_type = dst_type->get_compute_type(); |
30 | } |
31 | if (dst_type != val->ret_type) { |
32 | auto promoted = promoted_type(dst_type, val->ret_type); |
33 | if (dst_type != promoted) { |
34 | TI_WARN("[{}] {} may lose precision: {} <- {}\n{}" , stmt->name(), |
35 | stmt_name, dst_type->to_string(), val->ret_data_type_name(), |
36 | stmt->tb); |
37 | } |
38 | val = insert_type_cast_before(stmt, val, dst_type); |
39 | } |
40 | return dst_type; |
41 | } |
42 | |
43 | public: |
44 | explicit TypeCheck(const CompileConfig &config) : config_(config) { |
45 | allow_undefined_visitor = true; |
46 | } |
47 | |
48 | static void mark_as_if_const(Stmt *stmt, DataType t) { |
49 | if (stmt->is<ConstStmt>()) { |
50 | stmt->ret_type = t; |
51 | } |
52 | } |
53 | |
54 | void visit(AllocaStmt *stmt) override { |
55 | // Do nothing. Alloca type is determined by the first LocalStore in IR |
56 | // visiting order, at compile time. |
57 | |
58 | // ret_type stands for its element type. |
59 | } |
60 | |
61 | void visit(IfStmt *if_stmt) override { |
62 | if (if_stmt->true_statements) |
63 | if_stmt->true_statements->accept(this); |
64 | if (if_stmt->false_statements) { |
65 | if_stmt->false_statements->accept(this); |
66 | } |
67 | } |
68 | |
69 | void visit(Block *stmt_list) override { |
70 | std::vector<Stmt *> stmts; |
71 | // Make a copy since type casts may be inserted for type promotion. |
72 | for (auto &stmt : stmt_list->statements) { |
73 | stmts.push_back(stmt.get()); |
74 | } |
75 | for (auto stmt : stmts) |
76 | stmt->accept(this); |
77 | } |
78 | |
79 | void visit(AtomicOpStmt *stmt) override { |
80 | // TODO(type): test_ad_for fails if we assume dest is a pointer type. |
81 | stmt->ret_type = type_check_store( |
82 | stmt, stmt->dest, stmt->val, |
83 | fmt::format("Atomic {}" , atomic_op_type_name(stmt->op_type))); |
84 | } |
85 | |
86 | void visit(LocalLoadStmt *stmt) override { |
87 | TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>()); |
88 | if (auto ptr_offset_stmt = stmt->src->cast<MatrixPtrStmt>()) { |
89 | TI_ASSERT(ptr_offset_stmt->origin->is<AllocaStmt>() || |
90 | ptr_offset_stmt->origin->is<GlobalTemporaryStmt>()); |
91 | if (auto alloca_stmt = ptr_offset_stmt->origin->cast<AllocaStmt>()) { |
92 | auto lookup = |
93 | DataType( |
94 | alloca_stmt->ret_type->as<TensorType>()->get_element_type()) |
95 | .ptr_removed(); |
96 | stmt->ret_type = lookup; |
97 | } |
98 | if (auto global_temporary_stmt = |
99 | ptr_offset_stmt->origin->cast<GlobalTemporaryStmt>()) { |
100 | auto lookup = DataType(global_temporary_stmt->ret_type->as<TensorType>() |
101 | ->get_element_type()) |
102 | .ptr_removed(); |
103 | stmt->ret_type = lookup; |
104 | } |
105 | } else { |
106 | auto lookup = stmt->src->ret_type; |
107 | stmt->ret_type = lookup; |
108 | } |
109 | } |
110 | |
111 | void visit(LocalStoreStmt *stmt) override { |
112 | if (stmt->dest->ret_type->is_primitive(PrimitiveTypeID::unknown)) { |
113 | // Infer data type for alloca |
114 | stmt->dest->ret_type = stmt->val->ret_type; |
115 | } |
116 | stmt->ret_type = |
117 | type_check_store(stmt, stmt->dest, stmt->val, "Local store" ); |
118 | } |
119 | |
120 | void visit(GlobalLoadStmt *stmt) override { |
121 | auto pointee_type = stmt->src->ret_type.ptr_removed(); |
122 | stmt->ret_type = pointee_type->get_compute_type(); |
123 | } |
124 | |
125 | void visit(SNodeOpStmt *stmt) override { |
126 | if (stmt->op_type == SNodeOpType::get_addr) { |
127 | stmt->ret_type = PrimitiveType::u64; |
128 | } else if (stmt->op_type == SNodeOpType::allocate) { |
129 | stmt->ret_type = PrimitiveType::gen; |
130 | stmt->ret_type.set_is_pointer(true); |
131 | } else { |
132 | stmt->ret_type = PrimitiveType::i32; |
133 | } |
134 | } |
135 | |
136 | void visit(ExternalTensorShapeAlongAxisStmt *stmt) override { |
137 | stmt->ret_type = PrimitiveType::i32; |
138 | } |
139 | |
140 | void visit(GlobalPtrStmt *stmt) override { |
141 | if (stmt->is_bit_vectorized) { |
142 | return; |
143 | } |
144 | stmt->ret_type.set_is_pointer(true); |
145 | if (stmt->snode) { |
146 | stmt->ret_type = |
147 | TypeFactory::get_instance().get_pointer_type(stmt->snode->dt); |
148 | } else |
149 | TI_WARN("[{}] Type inference failed: snode is nullptr.\n{}" , stmt->name(), |
150 | stmt->tb); |
151 | auto check_indices = [&](SNode *snode) { |
152 | if (snode->num_active_indices != stmt->indices.size()) { |
153 | TI_ERROR("[{}] {} has {} indices. Indexed with {}." , stmt->name(), |
154 | snode->node_type_name, snode->num_active_indices, |
155 | stmt->indices.size()); |
156 | } |
157 | }; |
158 | check_indices(stmt->is_cell_access ? stmt->snode : stmt->snode->parent); |
159 | for (int i = 0; i < stmt->indices.size(); i++) { |
160 | if (!stmt->indices[i]->ret_type->is_primitive(PrimitiveTypeID::i32)) { |
161 | TI_WARN( |
162 | "[{}] Field index {} not int32, casting into int32 " |
163 | "implicitly\n{}" , |
164 | stmt->name(), i, stmt->tb); |
165 | stmt->indices[i] = |
166 | insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32); |
167 | } |
168 | } |
169 | } |
170 | |
171 | void visit(MatrixPtrStmt *stmt) override { |
172 | TI_ASSERT(stmt->offset->ret_type->is_primitive(PrimitiveTypeID::i32)); |
173 | stmt->ret_type.set_is_pointer(true); |
174 | } |
175 | |
176 | void visit(GlobalStoreStmt *stmt) override { |
177 | type_check_store(stmt, stmt->dest, stmt->val, "Global store" ); |
178 | } |
179 | |
180 | void visit(RangeForStmt *stmt) override { |
181 | mark_as_if_const(stmt->begin, PrimitiveType::i32); |
182 | mark_as_if_const(stmt->end, PrimitiveType::i32); |
183 | stmt->body->accept(this); |
184 | } |
185 | |
186 | void visit(StructForStmt *stmt) override { |
187 | stmt->body->accept(this); |
188 | } |
189 | |
190 | void visit(MeshForStmt *stmt) override { |
191 | stmt->body->accept(this); |
192 | } |
193 | |
194 | void visit(WhileStmt *stmt) override { |
195 | stmt->body->accept(this); |
196 | } |
197 | |
198 | void visit(UnaryOpStmt *stmt) override { |
199 | auto operand_type = stmt->operand->ret_type; |
200 | stmt->ret_type = operand_type; |
201 | if (stmt->is_cast()) { |
202 | stmt->ret_type = stmt->cast_type; |
203 | if (operand_type->is<TensorType>() && |
204 | stmt->cast_type->is<PrimitiveType>()) { |
205 | auto ret_tensor_type = operand_type->as<TensorType>(); |
206 | auto tensor_shape = ret_tensor_type->get_shape(); |
207 | stmt->ret_type = TypeFactory::get_instance().create_tensor_type( |
208 | tensor_shape, stmt->cast_type); |
209 | } |
210 | } |
211 | |
212 | DataType primitive_dtype = stmt->operand->ret_type.get_element_type(); |
213 | if (!is_real(primitive_dtype)) { |
214 | if (stmt->op_type == UnaryOpType::sqrt || |
215 | stmt->op_type == UnaryOpType::exp || |
216 | stmt->op_type == UnaryOpType::log) { |
217 | DataType target_dtype = config_.default_fp; |
218 | if (stmt->operand->ret_type->is<TensorType>()) { |
219 | target_dtype = TypeFactory::get_instance().create_tensor_type( |
220 | stmt->operand->ret_type->as<TensorType>()->get_shape(), |
221 | target_dtype); |
222 | } |
223 | |
224 | cast(stmt->operand, target_dtype); |
225 | stmt->ret_type = target_dtype; |
226 | } |
227 | } |
228 | } |
229 | |
230 | Stmt *insert_type_cast_before(Stmt *anchor, |
231 | Stmt *input, |
232 | DataType output_type) { |
233 | auto &&cast_stmt = |
234 | Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, input); |
235 | cast_stmt->cast_type = output_type; |
236 | cast_stmt->accept(this); |
237 | auto stmt = cast_stmt.get(); |
238 | anchor->insert_before_me(std::move(cast_stmt)); |
239 | return stmt; |
240 | } |
241 | |
242 | Stmt *insert_type_cast_after(Stmt *anchor, |
243 | Stmt *input, |
244 | DataType output_type) { |
245 | auto &&cast_stmt = |
246 | Stmt::make_typed<UnaryOpStmt>(UnaryOpType::cast_value, input); |
247 | cast_stmt->cast_type = output_type; |
248 | cast_stmt->accept(this); |
249 | auto stmt = cast_stmt.get(); |
250 | anchor->insert_after_me(std::move(cast_stmt)); |
251 | return stmt; |
252 | } |
253 | |
254 | void insert_shift_op_assertion_before(Stmt *stmt, Stmt *lhs, Stmt *rhs) { |
255 | int rhs_limit = data_type_bits(lhs->ret_type); |
256 | auto const_stmt = |
257 | Stmt::make<ConstStmt>(TypedConstant(rhs->ret_type, rhs_limit)); |
258 | auto cond_stmt = |
259 | Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_le, rhs, const_stmt.get()); |
260 | |
261 | std::string msg = |
262 | "Detected overflow for bit_shift_op with rhs = %d, exceeding limit of " |
263 | "%d." ; |
264 | msg += "\n" + stmt->tb; |
265 | std::vector<Stmt *> args = {rhs, const_stmt.get()}; |
266 | auto assert_stmt = |
267 | Stmt::make<AssertStmt>(cond_stmt.get(), msg, std::move(args)); |
268 | |
269 | const_stmt->accept(this); |
270 | cond_stmt->accept(this); |
271 | assert_stmt->accept(this); |
272 | |
273 | stmt->insert_before_me(std::move(const_stmt)); |
274 | stmt->insert_before_me(std::move(cond_stmt)); |
275 | stmt->insert_before_me(std::move(assert_stmt)); |
276 | } |
277 | |
278 | void cast(Stmt *&val, DataType dt) { |
279 | if (val->ret_type == dt) |
280 | return; |
281 | |
282 | auto cast_stmt = insert_type_cast_after(val, val, dt); |
283 | val = cast_stmt; |
284 | } |
285 | |
286 | void visit(BinaryOpStmt *stmt) override { |
287 | auto error = [&](std::string = "" ) { |
288 | if (comment == "" ) { |
289 | TI_WARN("[{}] Type mismatch (left = {}, right = {}, stmt_id = {})\n{}" , |
290 | stmt->name(), stmt->lhs->ret_data_type_name(), |
291 | stmt->rhs->ret_data_type_name(), stmt->id, stmt->tb); |
292 | } else { |
293 | TI_WARN("[{}] {}\n{}" , stmt->name(), comment, stmt->tb); |
294 | } |
295 | TI_WARN("Compilation stopped due to type mismatch." ); |
296 | throw std::runtime_error("Binary operator type mismatch" ); |
297 | }; |
298 | if (stmt->lhs->ret_type->is_primitive(PrimitiveTypeID::unknown) && |
299 | stmt->rhs->ret_type->is_primitive(PrimitiveTypeID::unknown)) |
300 | error(); |
301 | if (stmt->op_type == BinaryOpType::pow && |
302 | (is_integral(stmt->rhs->ret_type.get_element_type()))) { |
303 | stmt->ret_type = stmt->lhs->ret_type; |
304 | return; |
305 | } |
306 | |
307 | auto make_dt = [stmt](DataType dt) { |
308 | if (auto tensor_ty = stmt->lhs->ret_type->cast<TensorType>()) { |
309 | return TypeFactory::create_tensor_type(tensor_ty->get_shape(), dt); |
310 | } else { |
311 | return dt; |
312 | } |
313 | }; |
314 | |
315 | // lower truediv into div |
316 | |
317 | if (stmt->op_type == BinaryOpType::truediv) { |
318 | auto default_fp = config_.default_fp; |
319 | if (!is_real(stmt->lhs->ret_type.get_element_type())) { |
320 | cast(stmt->lhs, make_dt(default_fp)); |
321 | } |
322 | if (!is_real(stmt->rhs->ret_type.get_element_type())) { |
323 | cast(stmt->rhs, make_dt(default_fp)); |
324 | } |
325 | stmt->op_type = BinaryOpType::div; |
326 | } |
327 | |
328 | // Some backends such as vulkan doesn't support fp64 |
329 | // Always promote to fp32 unless necessary |
330 | if (stmt->op_type == BinaryOpType::atan2) { |
331 | if (stmt->rhs->ret_type == PrimitiveType::f64 || |
332 | stmt->lhs->ret_type == PrimitiveType::f64) { |
333 | stmt->ret_type = make_dt(PrimitiveType::f64); |
334 | cast(stmt->rhs, make_dt(PrimitiveType::f64)); |
335 | cast(stmt->lhs, make_dt(PrimitiveType::f64)); |
336 | } else { |
337 | stmt->ret_type = make_dt(PrimitiveType::f32); |
338 | cast(stmt->rhs, make_dt(PrimitiveType::f32)); |
339 | cast(stmt->lhs, make_dt(PrimitiveType::f32)); |
340 | } |
341 | } |
342 | |
343 | if (stmt->lhs->ret_type != stmt->rhs->ret_type) { |
344 | DataType ret_type; |
345 | if (is_shift_op(stmt->op_type)) { |
346 | // shift_ops does not follow the same type promotion rule as numerical |
347 | // ops numerical ops: u8 + i32 = i32 shift_ops: u8 << i32 = u8 |
348 | // (return dtype follows that of the lhs) |
349 | // |
350 | // In the above example, while truncating rhs(i32) to u8 risks an |
351 | // overflow, the runtime value of rhs is very likely less than 8 |
352 | // (otherwise meaningless). Nevertheless, we insert an AssertStmt here |
353 | // to warn user of this potential overflow. |
354 | ret_type = stmt->lhs->ret_type; |
355 | |
356 | // Insert AssertStmt |
357 | if (config_.debug) { |
358 | insert_shift_op_assertion_before(stmt, stmt->lhs, stmt->rhs); |
359 | } |
360 | } else { |
361 | ret_type = promoted_type(stmt->lhs->ret_type, stmt->rhs->ret_type); |
362 | } |
363 | |
364 | if (ret_type != stmt->lhs->ret_type) { |
365 | // promote lhs |
366 | auto cast_stmt = insert_type_cast_before(stmt, stmt->lhs, ret_type); |
367 | stmt->lhs = cast_stmt; |
368 | } |
369 | if (ret_type != stmt->rhs->ret_type) { |
370 | // promote rhs |
371 | auto cast_stmt = insert_type_cast_before(stmt, stmt->rhs, ret_type); |
372 | stmt->rhs = cast_stmt; |
373 | } |
374 | } |
375 | bool matching = true; |
376 | matching = matching && (stmt->lhs->ret_type != PrimitiveType::unknown); |
377 | matching = matching && (stmt->rhs->ret_type != PrimitiveType::unknown); |
378 | matching = matching && (stmt->lhs->ret_type == stmt->rhs->ret_type); |
379 | if (!matching) { |
380 | error(); |
381 | } |
382 | if (is_comparison(stmt->op_type)) { |
383 | stmt->ret_type = make_dt(PrimitiveType::i32); |
384 | } else { |
385 | stmt->ret_type = stmt->lhs->ret_type; |
386 | } |
387 | } |
388 | |
389 | void visit(TernaryOpStmt *stmt) override { |
390 | if (stmt->op_type == TernaryOpType::select) { |
391 | auto ret_type = promoted_type(stmt->op2->ret_type, stmt->op3->ret_type); |
392 | TI_ASSERT(stmt->op1->ret_type->is_primitive(PrimitiveTypeID::i32)); |
393 | if (ret_type != stmt->op2->ret_type) { |
394 | auto cast_stmt = insert_type_cast_before(stmt, stmt->op2, ret_type); |
395 | stmt->op2 = cast_stmt; |
396 | } |
397 | if (ret_type != stmt->op3->ret_type) { |
398 | auto cast_stmt = insert_type_cast_before(stmt, stmt->op3, ret_type); |
399 | stmt->op3 = cast_stmt; |
400 | } |
401 | stmt->ret_type = ret_type; |
402 | } else { |
403 | TI_NOT_IMPLEMENTED |
404 | } |
405 | } |
406 | |
407 | void visit(RangeAssumptionStmt *stmt) override { |
408 | stmt->ret_type = stmt->input->ret_type; |
409 | } |
410 | |
411 | void visit(LoopUniqueStmt *stmt) override { |
412 | stmt->ret_type = stmt->input->ret_type; |
413 | } |
414 | |
415 | void visit(FuncCallStmt *stmt) override { |
416 | auto *func = stmt->func; |
417 | TI_ASSERT(func); |
418 | stmt->ret_type = func->ret_type; |
419 | } |
420 | |
421 | void visit(FrontendFuncCallStmt *stmt) override { |
422 | auto *func = stmt->func; |
423 | TI_ASSERT(func); |
424 | stmt->ret_type = func->ret_type; |
425 | } |
426 | |
427 | void visit(GetElementStmt *stmt) override { |
428 | stmt->ret_type = |
429 | stmt->src->ret_type->as<StructType>()->get_element_type(stmt->index); |
430 | } |
431 | |
432 | void visit(ArgLoadStmt *stmt) override { |
433 | // TODO: Maybe have a type_inference() pass, which takes in the args/rets |
434 | // defined by the kernel. After that, type_check() pass will purely do |
435 | // verification, without modifying any types. |
436 | stmt->ret_type.set_is_pointer(stmt->is_ptr); |
437 | } |
438 | |
439 | void visit(ReturnStmt *stmt) override { |
440 | // TODO: Support stmt->ret_id? |
441 | } |
442 | |
443 | void visit(ExternalPtrStmt *stmt) override { |
444 | /* ExternalPtrStmt may have two different semantics: |
445 | 1. outer indexing to an argloaded external tensor |
446 | 2. outer indexing + inner indexing to get the innermost primitive |
447 | element of an external tensor |
448 | We rely on "external_dims" and "indices" to distinguish these two cases. |
449 | Case #1: external_dims == indices.size(), return TensorType |
450 | Case #2: external_dims < indices.size(), return PrimitiveType |
451 | */ |
452 | TI_ASSERT(stmt->base_ptr->is<ArgLoadStmt>()); |
453 | auto arg_load_stmt = stmt->base_ptr->cast<ArgLoadStmt>(); |
454 | |
455 | int external_dims = arg_load_stmt->field_dims_; |
456 | if (external_dims == stmt->indices.size() || external_dims == -1) { |
457 | stmt->ret_type = arg_load_stmt->ret_type; |
458 | } else { |
459 | stmt->ret_type = arg_load_stmt->ret_type.ptr_removed().get_element_type(); |
460 | } |
461 | |
462 | stmt->ret_type.set_is_pointer(true); |
463 | for (int i = 0; i < stmt->indices.size(); i++) { |
464 | TI_ASSERT(is_integral(stmt->indices[i]->ret_type)); |
465 | if (stmt->indices[i]->ret_type != PrimitiveType::i32) { |
466 | stmt->indices[i] = |
467 | insert_type_cast_before(stmt, stmt->indices[i], PrimitiveType::i32); |
468 | } |
469 | } |
470 | } |
471 | |
472 | void visit(LoopIndexStmt *stmt) override { |
473 | stmt->ret_type = PrimitiveType::i32; |
474 | } |
475 | |
476 | void visit(LoopLinearIndexStmt *stmt) override { |
477 | stmt->ret_type = PrimitiveType::i32; |
478 | } |
479 | |
480 | void visit(BlockCornerIndexStmt *stmt) override { |
481 | stmt->ret_type = PrimitiveType::i32; |
482 | } |
483 | |
484 | void visit(GetRootStmt *stmt) override { |
485 | stmt->ret_type = |
486 | TypeFactory::get_instance().get_pointer_type(PrimitiveType::gen); |
487 | } |
488 | |
489 | void visit(SNodeLookupStmt *stmt) override { |
490 | if (stmt->snode->type == SNodeType::quant_array) { |
491 | auto quant_array_type = stmt->snode->dt; |
492 | auto element_type = |
493 | quant_array_type->cast<QuantArrayType>()->get_element_type(); |
494 | auto pointer_type = |
495 | TypeFactory::get_instance().get_pointer_type(element_type, true); |
496 | stmt->ret_type = pointer_type; |
497 | } else { |
498 | stmt->ret_type = |
499 | TypeFactory::get_instance().get_pointer_type(PrimitiveType::gen); |
500 | } |
501 | } |
502 | |
503 | void visit(GetChStmt *stmt) override { |
504 | if (stmt->is_bit_vectorized) { |
505 | auto physical_type = stmt->output_snode->physical_type; |
506 | auto ptr_ret_type = |
507 | TypeFactory::get_instance().get_pointer_type(physical_type); |
508 | stmt->ret_type = DataType(ptr_ret_type); |
509 | return; |
510 | } |
511 | auto element_type = stmt->output_snode->dt; |
512 | // For bit_struct SNodes, their component SNodes must have |
513 | // is_bit_level=true |
514 | auto pointer_type = TypeFactory::get_instance().get_pointer_type( |
515 | element_type, stmt->output_snode->is_bit_level); |
516 | stmt->ret_type = pointer_type; |
517 | } |
518 | |
519 | void visit(OffloadedStmt *stmt) override { |
520 | stmt->all_blocks_accept(this); |
521 | } |
522 | |
523 | void visit(LinearizeStmt *stmt) override { |
524 | stmt->ret_type = PrimitiveType::i32; |
525 | } |
526 | |
527 | void visit(IntegerOffsetStmt *stmt) override { |
528 | stmt->ret_type = PrimitiveType::i32; |
529 | } |
530 | |
531 | void visit(AdStackAllocaStmt *stmt) override { |
532 | stmt->ret_type = stmt->dt; |
533 | // ret_type stands for its element type. |
534 | stmt->ret_type.set_is_pointer(false); |
535 | } |
536 | |
537 | void visit(AdStackLoadTopStmt *stmt) override { |
538 | stmt->ret_type = stmt->stack->ret_type; |
539 | stmt->ret_type.set_is_pointer(false); |
540 | } |
541 | |
542 | void visit(AdStackLoadTopAdjStmt *stmt) override { |
543 | stmt->ret_type = stmt->stack->ret_type; |
544 | stmt->ret_type.set_is_pointer(false); |
545 | } |
546 | |
547 | void visit(AdStackPushStmt *stmt) override { |
548 | stmt->ret_type = stmt->stack->ret_type; |
549 | stmt->ret_type.set_is_pointer(false); |
550 | TI_ASSERT(stmt->ret_type == stmt->v->ret_type); |
551 | } |
552 | |
553 | void visit(AdStackPopStmt *stmt) override { |
554 | stmt->ret_type = stmt->stack->ret_type; |
555 | stmt->ret_type.set_is_pointer(false); |
556 | } |
557 | |
558 | void visit(AdStackAccAdjointStmt *stmt) override { |
559 | stmt->ret_type = stmt->stack->ret_type; |
560 | stmt->ret_type.set_is_pointer(false); |
561 | TI_ASSERT(stmt->ret_type == stmt->v->ret_type); |
562 | } |
563 | |
564 | void visit(GlobalTemporaryStmt *stmt) override { |
565 | stmt->ret_type.set_is_pointer(true); |
566 | } |
567 | |
568 | void visit(InternalFuncStmt *stmt) override { |
569 | // TODO: support return type specification |
570 | stmt->ret_type = PrimitiveType::i32; |
571 | } |
572 | |
573 | void visit(BitStructStoreStmt *stmt) override { |
574 | // do nothing |
575 | } |
576 | |
577 | void visit(ReferenceStmt *stmt) override { |
578 | stmt->ret_type = stmt->var->ret_type; |
579 | stmt->ret_type.set_is_pointer(true); |
580 | } |
581 | |
582 | void visit(MatrixInitStmt *stmt) override { |
583 | TI_ASSERT_INFO(stmt->ret_type->is<TensorType>(), |
584 | "Matrix should have tensor type, got {}" , |
585 | stmt->ret_type->to_string()); |
586 | auto tensor_type = stmt->ret_type->as<TensorType>(); |
587 | auto element_dtype = tensor_type->get_element_type(); |
588 | for (int i = 0; i < stmt->values.size(); ++i) { |
589 | if (element_dtype != stmt->values[i]->ret_type) { |
590 | cast(stmt->values[i], element_dtype); |
591 | } |
592 | } |
593 | } |
594 | }; |
595 | |
596 | namespace irpass { |
597 | |
598 | void type_check(IRNode *root, const CompileConfig &config) { |
599 | TI_AUTO_PROF; |
600 | analysis::check_fields_registered(root); |
601 | TypeCheck inst(config); |
602 | root->accept(&inst); |
603 | } |
604 | |
605 | } // namespace irpass |
606 | |
607 | } // namespace taichi::lang |
608 | |