1 | #include "taichi/ir/frontend_ir.h" |
2 | |
3 | #include "taichi/ir/expression_printer.h" |
4 | #include "taichi/ir/statements.h" |
5 | #include "taichi/program/program.h" |
6 | #include "taichi/common/exceptions.h" |
7 | |
8 | #include <numeric> |
9 | |
10 | namespace taichi::lang { |
11 | |
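// Asserts that expression `x` has already been assigned a concrete return
// type by a previous type_check() pass.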
12 | #define TI_ASSERT_TYPE_CHECKED(x) \ |
13 | TI_ASSERT_INFO(x->ret_type != PrimitiveType::unknown, \ |
14 | "[{}] was not type-checked", \ |
15 | ExpressionHumanFriendlyPrinter::expr_to_string(x)) |
16 | |
static bool is_primitive_or_tensor_type(const DataType &type) {
  return type->is<PrimitiveType>() || type->is<TensorType>();
}
20 | |
21 | FrontendSNodeOpStmt::FrontendSNodeOpStmt(SNodeOpType op_type, |
22 | SNode *snode, |
23 | const ExprGroup &indices, |
24 | const Expr &val) |
25 | : op_type(op_type), snode(snode), indices(indices), val(val) { |
26 | if (val.expr != nullptr) { |
27 | TI_ASSERT(op_type == SNodeOpType::append); |
28 | } else { |
29 | TI_ASSERT(op_type != SNodeOpType::append); |
30 | } |
31 | } |
32 | |
33 | FrontendReturnStmt::FrontendReturnStmt(const ExprGroup &group) : values(group) { |
34 | } |
35 | |
36 | FrontendAssignStmt::FrontendAssignStmt(const Expr &lhs, const Expr &rhs) |
37 | : lhs(lhs), rhs(rhs) { |
38 | TI_ASSERT(lhs->is_lvalue()); |
39 | if (lhs.is<IdExpression>() && lhs->ret_type == PrimitiveType::unknown) { |
40 | lhs.expr->ret_type = rhs->ret_type; |
41 | } |
42 | } |
43 | |
44 | FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars, |
45 | SNode *snode, |
46 | Arch arch, |
47 | const ForLoopConfig &config) |
48 | : snode(snode) { |
49 | init_config(arch, config); |
50 | init_loop_vars(loop_vars); |
51 | } |
52 | |
53 | FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars, |
54 | const Expr &external_tensor, |
55 | Arch arch, |
56 | const ForLoopConfig &config) |
57 | : external_tensor(external_tensor) { |
58 | init_config(arch, config); |
59 | init_loop_vars(loop_vars); |
60 | } |
61 | |
62 | FrontendForStmt::FrontendForStmt(const ExprGroup &loop_vars, |
63 | const mesh::MeshPtr &mesh, |
64 | const mesh::MeshElementType &element_type, |
65 | Arch arch, |
66 | const ForLoopConfig &config) |
67 | : mesh(mesh.ptr.get()), element_type(element_type) { |
68 | init_config(arch, config); |
69 | init_loop_vars(loop_vars); |
70 | } |
71 | |
72 | FrontendForStmt::FrontendForStmt(const Expr &loop_var, |
73 | const Expr &begin, |
74 | const Expr &end, |
75 | Arch arch, |
76 | const ForLoopConfig &config) |
77 | : begin(begin), end(end) { |
78 | init_config(arch, config); |
79 | add_loop_var(loop_var); |
80 | } |
81 | |
82 | void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) { |
83 | is_bit_vectorized = config.is_bit_vectorized; |
84 | strictly_serialized = config.strictly_serialized; |
85 | mem_access_opt = config.mem_access_opt; |
86 | block_dim = config.block_dim; |
87 | if (arch == Arch::cuda || arch == Arch::amdgpu) { |
88 | num_cpu_threads = 1; |
89 | TI_ASSERT(block_dim <= taichi_max_gpu_block_dim); |
90 | } else { // cpu |
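    // num_cpu_threads == 0 means "auto": use all available hardware threads.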
91 | if (config.num_cpu_threads == 0) { |
92 | num_cpu_threads = std::thread::hardware_concurrency(); |
93 | } else { |
94 | num_cpu_threads = config.num_cpu_threads; |
95 | } |
96 | } |
97 | } |
98 | |
99 | void FrontendForStmt::init_loop_vars(const ExprGroup &loop_vars) { |
100 | loop_var_ids.reserve(loop_vars.size()); |
101 | for (int i = 0; i < (int)loop_vars.size(); i++) { |
102 | add_loop_var(loop_vars[i]); |
103 | } |
104 | } |
105 | |
106 | void FrontendForStmt::add_loop_var(const Expr &loop_var) { |
107 | loop_var_ids.push_back(loop_var.cast<IdExpression>()->id); |
108 | loop_var.expr->ret_type = PrimitiveType::i32; |
109 | } |
110 | |
111 | void ArgLoadExpression::type_check(const CompileConfig *) { |
112 | TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown, |
                 "Invalid dt [{}] for ArgLoadExpression", dt->to_string());
114 | ret_type = dt; |
115 | } |
116 | |
117 | void ArgLoadExpression::flatten(FlattenContext *ctx) { |
118 | auto arg_load = std::make_unique<ArgLoadStmt>(arg_id, dt, is_ptr); |
119 | ctx->push_back(std::move(arg_load)); |
120 | stmt = ctx->back_stmt(); |
121 | } |
122 | |
void TexturePtrExpression::type_check(const CompileConfig *config) {
  // Nothing to infer: a texture pointer is an opaque handle and carries no
  // meaningful return type.
}
125 | |
126 | void TexturePtrExpression::flatten(FlattenContext *ctx) { |
127 | ctx->push_back<ArgLoadStmt>(arg_id, PrimitiveType::f32, true); |
128 | ctx->push_back<TexturePtrStmt>(ctx->back_stmt(), num_dims, is_storage, |
129 | num_channels, channel_format, lod); |
130 | stmt = ctx->back_stmt(); |
131 | } |
132 | |
133 | void RandExpression::type_check(const CompileConfig *) { |
134 | TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown, |
                 "Invalid dt [{}] for RandExpression", dt->to_string());
136 | ret_type = dt; |
137 | } |
138 | |
139 | void RandExpression::flatten(FlattenContext *ctx) { |
140 | auto ran = std::make_unique<RandStmt>(dt); |
141 | ctx->push_back(std::move(ran)); |
142 | stmt = ctx->back_stmt(); |
143 | } |
144 | |
145 | void UnaryOpExpression::type_check(const CompileConfig *config) { |
146 | TI_ASSERT_TYPE_CHECKED(operand); |
147 | |
148 | TI_ASSERT(config != nullptr); |
  /*
    Dtype inference for TensorType and PrimitiveType is essentially the
    same. Therefore we extract the primitive type to perform the type
    inference, and then reconstruct the TensorType when necessary.
  */
154 | |
155 | auto operand_primitive_type = operand->ret_type.get_element_type(); |
156 | auto ret_primitive_type = ret_type; |
157 | |
158 | if (!operand_primitive_type->is<PrimitiveType>()) { |
159 | throw TaichiTypeError(fmt::format( |
        "unsupported operand type(s) for '{}': '{}'", unary_op_type_name(type),
161 | operand_primitive_type->to_string())); |
162 | } |
163 | |
164 | if ((type == UnaryOpType::round || type == UnaryOpType::floor || |
165 | type == UnaryOpType::ceil || is_trigonometric(type)) && |
166 | !is_real(operand_primitive_type)) |
167 | throw TaichiTypeError(fmt::format( |
        "'{}' takes real inputs only, but '{}' was provided",
169 | unary_op_type_name(type), operand_primitive_type->to_string())); |
170 | |
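  // sqrt/exp/log on an integral operand promote the result to the default
  // floating-point type; other non-cast ops keep the operand's dtype.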
171 | if ((type == UnaryOpType::sqrt || type == UnaryOpType::exp || |
172 | type == UnaryOpType::log) && |
173 | !is_real(operand_primitive_type)) { |
174 | ret_primitive_type = config->default_fp; |
175 | } else { |
176 | ret_primitive_type = is_cast() ? cast_type : operand_primitive_type; |
177 | } |
178 | |
179 | if (operand->ret_type->is<TensorType>()) { |
180 | ret_type = taichi::lang::TypeFactory::get_instance().get_tensor_type( |
181 | operand->ret_type.get_shape(), ret_primitive_type); |
182 | } else { |
183 | TI_ASSERT(operand->ret_type->is<PrimitiveType>()); |
184 | ret_type = ret_primitive_type; |
185 | } |
186 | } |
187 | |
188 | bool UnaryOpExpression::is_cast() const { |
189 | return unary_op_is_cast(type); |
190 | } |
191 | |
192 | void UnaryOpExpression::flatten(FlattenContext *ctx) { |
193 | auto operand_stmt = flatten_rvalue(operand, ctx); |
194 | auto unary = std::make_unique<UnaryOpStmt>(type, operand_stmt); |
195 | if (is_cast()) { |
196 | unary->cast_type = cast_type; |
197 | } |
198 | stmt = unary.get(); |
199 | stmt->tb = tb; |
200 | stmt->ret_type = ret_type; |
201 | ctx->push_back(std::move(unary)); |
202 | } |
203 | |
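// Broadcasts a scalar expression to the shape of `dt` (a TensorType) by
// replicating it, e.g. an f32 scalar broadcast against a 4-element i32
// tensor becomes a 4-element f32 tensor; dtype promotion across operands is
// deferred to later type checking.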
204 | Expr to_broadcast_tensor(const Expr &elt, const DataType &dt) { |
205 | if (!elt->ret_type->is<TensorType>() && !dt->is<TensorType>()) |
206 | return elt; |
207 | |
208 | if (elt->ret_type->is<TensorType>() && dt->is<TensorType>()) { |
209 | // Only tensor shape will be checked here, since the dtype will |
210 | // be promoted later at irpass::type_check() |
211 | if (elt->ret_type.get_shape() != dt.get_shape()) { |
      TI_ERROR("Cannot broadcast tensor to tensor");
213 | } else { |
214 | return elt; |
215 | } |
216 | } |
217 | |
218 | auto tensor_type = dt->as<TensorType>(); |
219 | auto elt_type = tensor_type->get_element_type(); |
220 | TI_ASSERT_INFO(elt_type->is<PrimitiveType>(), |
                 "Only primitive types are supported in Tensors, got {}",
222 | elt_type->to_string()); |
223 | std::vector<Expr> broadcast_values(tensor_type->get_num_elements(), elt); |
224 | auto matrix_expr = Expr::make<MatrixExpression>( |
225 | broadcast_values, tensor_type->get_shape(), elt->ret_type); |
226 | matrix_expr->type_check(nullptr); |
227 | return matrix_expr; |
228 | } |
229 | |
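// For a scalar/tensor operand pair, broadcasts the scalar side to the
// tensor's shape so that the binary op becomes element-wise; any other pair
// is returned unchanged.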
230 | std::tuple<Expr, Expr> unify_binop_operands(const Expr &e1, const Expr &e2) { |
231 | if (e1->ret_type->is<PrimitiveType>() && e2->ret_type->is<TensorType>()) { |
232 | return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2); |
233 | } else if (e1->ret_type->is<TensorType>() && |
234 | e2->ret_type->is<PrimitiveType>()) { |
235 | return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type)); |
236 | } else { |
237 | return std::tuple(e1, e2); |
238 | } |
239 | } |
240 | |
241 | void BinaryOpExpression::type_check(const CompileConfig *config) { |
242 | TI_ASSERT_TYPE_CHECKED(lhs); |
243 | TI_ASSERT_TYPE_CHECKED(rhs); |
244 | auto lhs_type = lhs->ret_type; |
245 | auto rhs_type = rhs->ret_type; |
246 | auto error = [&]() { |
247 | throw TaichiTypeError( |
        fmt::format("unsupported operand type(s) for '{}': '{}' and '{}'",
249 | binary_op_type_symbol(type), lhs->ret_type->to_string(), |
250 | rhs->ret_type->to_string())); |
251 | }; |
252 | |
253 | if (!is_primitive_or_tensor_type(lhs_type) || |
254 | !is_primitive_or_tensor_type(rhs_type)) { |
255 | error(); |
256 | } |
257 | |
258 | if ((lhs_type->is<PrimitiveType>() && rhs_type->is<TensorType>()) || |
259 | (lhs_type->is<TensorType>() && rhs_type->is<PrimitiveType>())) { |
    // Broadcast the scalar side of a Tensor-Scalar / Scalar-Tensor pair so
    // that the op becomes element-wise.
261 | auto [unified_l, unified_r] = unify_binop_operands(lhs, rhs); |
262 | lhs = unified_l; |
263 | rhs = unified_r; |
264 | if (lhs->ret_type == PrimitiveType::unknown) |
265 | lhs.type_check(config); |
266 | if (rhs->ret_type == PrimitiveType::unknown) |
267 | rhs.type_check(config); |
268 | TI_ASSERT(lhs->ret_type->is<TensorType>()); |
269 | TI_ASSERT(rhs->ret_type->is<TensorType>()); |
270 | lhs_type = lhs->ret_type; |
271 | rhs_type = rhs->ret_type; |
272 | } |
273 | |
274 | bool is_tensor_op = false; |
275 | |
276 | if (lhs_type->is<TensorType>()) { |
277 | is_tensor_op = true; |
278 | auto rhs_tensor_type = rhs_type->cast<TensorType>(); |
279 | if (rhs_tensor_type->get_shape() != |
280 | lhs_type->cast<TensorType>()->get_shape()) |
      // Currently only element-wise binary ops on same-shape tensors are
      // supported.
282 | error(); |
283 | } |
284 | |
285 | auto make_dt = [&is_tensor_op, this](DataType dt) { |
286 | if (is_tensor_op) { |
287 | return TypeFactory::create_tensor_type( |
288 | this->lhs->ret_type->cast<TensorType>()->get_shape(), dt); |
289 | } else { |
290 | return dt; |
291 | } |
292 | }; |
293 | |
294 | if (binary_is_bitwise(type) && (!is_integral(lhs_type.get_element_type()) || |
295 | !is_integral(rhs_type.get_element_type()))) |
296 | error(); |
297 | if (binary_is_logical(type) && |
298 | (is_tensor_op || lhs_type != PrimitiveType::i32 || |
299 | rhs_type != PrimitiveType::i32)) |
300 | error(); |
301 | if (is_comparison(type) || binary_is_logical(type)) { |
302 | ret_type = make_dt(PrimitiveType::i32); |
303 | return; |
304 | } |
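  // Shifts and pow with an integral exponent keep the lhs dtype.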
305 | if (is_shift_op(type) || |
306 | (type == BinaryOpType::pow && is_integral(rhs_type))) { |
307 | ret_type = lhs_type; |
308 | return; |
309 | } |
310 | |
  // Some backends, such as Vulkan, don't support fp64.
  // Try not to promote to fp64 unless necessary.
313 | if (type == BinaryOpType::atan2) { |
314 | if (lhs_type == PrimitiveType::f64 || rhs_type == PrimitiveType::f64) { |
315 | ret_type = make_dt(PrimitiveType::f64); |
316 | } else { |
317 | ret_type = make_dt(PrimitiveType::f32); |
318 | } |
319 | return; |
320 | } |
321 | |
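  // True division always produces a floating-point result: integral
  // operands are first promoted to the default floating-point type.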
322 | if (type == BinaryOpType::truediv) { |
323 | auto default_fp = config->default_fp; |
324 | if (!is_real(lhs_type.get_element_type())) { |
325 | lhs_type = make_dt(default_fp); |
326 | } |
327 | if (!is_real(rhs_type.get_element_type())) { |
328 | rhs_type = make_dt(default_fp); |
329 | } |
330 | } |
331 | ret_type = promoted_type(lhs_type, rhs_type); |
332 | } |
333 | |
334 | void BinaryOpExpression::flatten(FlattenContext *ctx) { |
335 | // if (stmt) |
336 | // return; |
337 | auto lhs_stmt = flatten_rvalue(lhs, ctx); |
338 | |
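  // Logical and/or are lowered with short-circuit semantics:
  //   result = lhs
  //   logical_and: if (result)  { result = rhs; }
  //   logical_or:  if (!result) { result = rhs; }
  // The rhs is flattened into the corresponding branch only, so it is
  // evaluated only when needed.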
339 | if (binary_is_logical(type)) { |
340 | auto result = ctx->push_back<AllocaStmt>(ret_type); |
341 | ctx->push_back<LocalStoreStmt>(result, lhs_stmt); |
342 | auto cond = ctx->push_back<LocalLoadStmt>(result); |
343 | auto if_stmt = ctx->push_back<IfStmt>(cond); |
344 | |
345 | FlattenContext rctx; |
346 | rctx.current_block = ctx->current_block; |
347 | auto rhs_stmt = flatten_rvalue(rhs, &rctx); |
348 | rctx.push_back<LocalStoreStmt>(result, rhs_stmt); |
349 | |
350 | auto true_block = std::make_unique<Block>(); |
351 | if (type == BinaryOpType::logical_and) { |
352 | true_block->set_statements(std::move(rctx.stmts)); |
353 | } |
354 | if_stmt->set_true_statements(std::move(true_block)); |
355 | |
356 | auto false_block = std::make_unique<Block>(); |
357 | if (type == BinaryOpType::logical_or) { |
358 | false_block->set_statements(std::move(rctx.stmts)); |
359 | } |
360 | if_stmt->set_false_statements(std::move(false_block)); |
361 | |
362 | auto ret = ctx->push_back<LocalLoadStmt>(result); |
363 | ret->tb = tb; |
364 | stmt = ret; |
365 | stmt->ret_type = ret_type; |
366 | return; |
367 | } |
368 | auto rhs_stmt = flatten_rvalue(rhs, ctx); |
369 | ctx->push_back(std::make_unique<BinaryOpStmt>(type, lhs_stmt, rhs_stmt)); |
370 | ctx->stmts.back()->tb = tb; |
371 | stmt = ctx->back_stmt(); |
372 | stmt->ret_type = ret_type; |
373 | } |
374 | |
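// Lowers a functional if-then-else into an alloca plus an IfStmt whose two
// branches each store their value into the shared slot; a final
// LocalLoadStmt yields the selected value.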
375 | void make_ifte(Expression::FlattenContext *ctx, |
376 | DataType ret_type, |
377 | Expr cond, |
378 | Expr true_val, |
379 | Expr false_val) { |
380 | auto result = ctx->push_back<AllocaStmt>(ret_type); |
381 | auto cond_stmt = flatten_rvalue(cond, ctx); |
382 | auto if_stmt = ctx->push_back<IfStmt>(cond_stmt); |
383 | |
384 | Expression::FlattenContext lctx; |
385 | lctx.current_block = ctx->current_block; |
386 | auto true_val_stmt = flatten_rvalue(true_val, &lctx); |
387 | lctx.push_back<LocalStoreStmt>(result, true_val_stmt); |
388 | |
389 | Expression::FlattenContext rctx; |
390 | rctx.current_block = ctx->current_block; |
391 | auto false_val_stmt = flatten_rvalue(false_val, &rctx); |
392 | rctx.push_back<LocalStoreStmt>(result, false_val_stmt); |
393 | |
394 | auto true_block = std::make_unique<Block>(); |
395 | true_block->set_statements(std::move(lctx.stmts)); |
396 | if_stmt->set_true_statements(std::move(true_block)); |
397 | |
398 | auto false_block = std::make_unique<Block>(); |
399 | false_block->set_statements(std::move(rctx.stmts)); |
400 | if_stmt->set_false_statements(std::move(false_block)); |
401 | |
402 | ctx->push_back<LocalLoadStmt>(result); |
403 | return; |
404 | } |
405 | |
406 | static std::tuple<Expr, Expr, Expr> unify_ternaryop_operands(const Expr &e1, |
407 | const Expr &e2, |
408 | const Expr &e3) { |
409 | auto target_dtype = PrimitiveType::unknown; |
410 | // Since we don't support broadcasting between two TensorTypes, |
411 | // we can simply use the first TensorType's dtype as the target dtype. |
412 | if (e1->ret_type->is<TensorType>()) { |
413 | target_dtype = e1->ret_type; |
414 | } else if (e2->ret_type->is<TensorType>()) { |
415 | target_dtype = e2->ret_type; |
416 | } else if (e3->ret_type->is<TensorType>()) { |
417 | target_dtype = e3->ret_type; |
418 | } |
419 | |
420 | if (target_dtype == PrimitiveType::unknown) { |
421 | return std::tuple(e1, e2, e3); |
422 | } |
423 | |
424 | return std::tuple(to_broadcast_tensor(e1, target_dtype), |
425 | to_broadcast_tensor(e2, target_dtype), |
426 | to_broadcast_tensor(e3, target_dtype)); |
427 | } |
428 | |
429 | void TernaryOpExpression::type_check(const CompileConfig *config) { |
430 | TI_ASSERT_TYPE_CHECKED(op1); |
431 | TI_ASSERT_TYPE_CHECKED(op2); |
432 | TI_ASSERT_TYPE_CHECKED(op3); |
433 | |
434 | bool is_valid = true; |
435 | bool is_tensor = false; |
436 | |
437 | auto [unified_cond, unified_l, unified_r] = |
438 | unify_ternaryop_operands(op1, op2, op3); |
439 | op1 = unified_cond; |
440 | op2 = unified_l; |
441 | op3 = unified_r; |
442 | auto op1_type = op1->ret_type; |
443 | auto op2_type = op2->ret_type; |
444 | auto op3_type = op3->ret_type; |
445 | |
446 | auto error = [&]() { |
447 | throw TaichiTypeError( |
        fmt::format("unsupported operand type(s) for '{}': '{}', '{}' and '{}'",
449 | ternary_type_name(type), op1->ret_type->to_string(), |
450 | op2->ret_type->to_string(), op3->ret_type->to_string())); |
451 | }; |
452 | |
453 | if (op1_type->is<TensorType>() && op2_type->is<TensorType>() && |
454 | op3_type->is<TensorType>()) { |
455 | // valid |
456 | is_tensor = true; |
457 | if (op1_type->cast<TensorType>()->get_shape() != |
458 | op2_type->cast<TensorType>()->get_shape()) { |
459 | is_valid = false; |
460 | } |
461 | if (op2_type->cast<TensorType>()->get_shape() != |
462 | op3_type->cast<TensorType>()->get_shape()) { |
463 | is_valid = false; |
464 | } |
465 | op1_type = op1_type->cast<TensorType>()->get_element_type(); |
466 | op2_type = op2_type->cast<TensorType>()->get_element_type(); |
467 | op3_type = op3_type->cast<TensorType>()->get_element_type(); |
468 | |
469 | } else if (op1_type->is<PrimitiveType>() && op2_type->is<PrimitiveType>() && |
470 | op3_type->is<PrimitiveType>()) { |
471 | // valid |
472 | } else { |
473 | is_valid = false; |
474 | } |
475 | |
476 | if (op1_type != PrimitiveType::i32) { |
477 | is_valid = false; |
478 | } |
479 | if (!op2_type->is<PrimitiveType>() || !op3_type->is<PrimitiveType>()) { |
480 | is_valid = false; |
481 | } |
482 | |
483 | if (!is_valid) |
484 | error(); |
485 | |
486 | if (is_tensor) { |
487 | auto primitive_dtype = promoted_type(op2_type, op3_type); |
488 | auto shape = op2->ret_type->cast<TensorType>()->get_shape(); |
489 | ret_type = TypeFactory::create_tensor_type(shape, primitive_dtype); |
490 | } else { |
491 | ret_type = promoted_type(op2_type, op3_type); |
492 | } |
493 | } |
494 | |
495 | void TernaryOpExpression::flatten(FlattenContext *ctx) { |
496 | // if (stmt) |
497 | // return; |
498 | if (type == TernaryOpType::select) { |
499 | auto op1_stmt = flatten_rvalue(op1, ctx); |
500 | auto op2_stmt = flatten_rvalue(op2, ctx); |
501 | auto op3_stmt = flatten_rvalue(op3, ctx); |
502 | ctx->push_back( |
503 | std::make_unique<TernaryOpStmt>(type, op1_stmt, op2_stmt, op3_stmt)); |
504 | } else if (type == TernaryOpType::ifte) { |
505 | make_ifte(ctx, ret_type, op1, op2, op3); |
506 | } |
507 | stmt = ctx->back_stmt(); |
508 | stmt->tb = tb; |
509 | stmt->ret_type = ret_type; |
510 | } |
511 | |
512 | void InternalFuncCallExpression::type_check(const CompileConfig *) { |
513 | for (auto &arg : args) { |
514 | TI_ASSERT_TYPE_CHECKED(arg); |
515 | // no arg type compatibility check for now due to lack of specification |
516 | } |
517 | // internal func calls have default return type |
518 | ret_type = PrimitiveType::i32; |
519 | } |
520 | |
521 | void InternalFuncCallExpression::flatten(FlattenContext *ctx) { |
522 | std::vector<Stmt *> args_stmts(args.size()); |
523 | for (int i = 0; i < (int)args.size(); ++i) { |
524 | args_stmts[i] = flatten_rvalue(args[i], ctx); |
525 | } |
526 | ctx->push_back<InternalFuncStmt>(func_name, args_stmts, nullptr, |
527 | with_runtime_context); |
528 | stmt = ctx->back_stmt(); |
529 | stmt->tb = tb; |
530 | } |
531 | |
532 | void ExternalTensorExpression::flatten(FlattenContext *ctx) { |
533 | // https://github.com/taichi-dev/taichi/issues/5819 |
  // ArgLoadStmt keeps primitive types since all matrix types get
  // scalarized at Python scope
536 | // |
537 | // FIXME(zhanlue): ArgLoadStmt should use TensorType once real_matrix is |
538 | // turned-on by default. |
539 | // The scalarization should happen after |
540 | // irpass::lower_access() |
541 | auto prim_dt = dt; |
542 | auto ptr = Stmt::make<ArgLoadStmt>(arg_id, prim_dt, /*is_ptr=*/true, |
543 | /*is_grad=*/is_grad); |
544 | |
545 | int external_dims = dim - std::abs(element_dim); |
546 | ptr->cast<ArgLoadStmt>()->set_extern_dims(external_dims); |
547 | |
548 | ptr->tb = tb; |
549 | ctx->push_back(std::move(ptr)); |
550 | stmt = ctx->back_stmt(); |
551 | } |
552 | |
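// Flattens the index expressions, subtracting the SNode's index offsets (if
// any) so that downstream passes work with unshifted coordinates.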
553 | std::vector<Stmt *> make_index_stmts(Expression::FlattenContext *ctx, |
554 | const ExprGroup &indices, |
555 | const std::vector<int> &offsets) { |
556 | std::vector<Stmt *> index_stmts; |
557 | for (int i = 0; i < (int)indices.size(); i++) { |
558 | Stmt *ind = flatten_rvalue(indices.exprs[i], ctx); |
559 | if (!offsets.empty()) { |
560 | auto offset = ctx->push_back<ConstStmt>(TypedConstant(offsets[i])); |
561 | ind = ctx->push_back<BinaryOpStmt>(BinaryOpType::sub, ind, offset); |
562 | } |
563 | index_stmts.push_back(ind); |
564 | } |
565 | return index_stmts; |
566 | } |
567 | |
568 | Stmt *make_field_access(Expression::FlattenContext *ctx, |
569 | const FieldExpression &field, |
570 | ExprGroup indices) { |
571 | return ctx->push_back(std::make_unique<GlobalPtrStmt>( |
572 | field.snode, make_index_stmts(ctx, indices, field.snode->index_offsets))); |
573 | } |
574 | |
575 | Stmt *make_matrix_field_access(Expression::FlattenContext *ctx, |
576 | const MatrixFieldExpression &matrix_field, |
577 | ExprGroup indices, |
578 | DataType ret_type) { |
579 | std::vector<SNode *> snodes; |
580 | for (auto &field : matrix_field.fields) { |
581 | snodes.push_back(field.cast<FieldExpression>()->snode); |
582 | } |
583 | return ctx->push_back(std::make_unique<MatrixOfGlobalPtrStmt>( |
584 | snodes, make_index_stmts(ctx, indices, snodes[0]->index_offsets), |
585 | matrix_field.dynamic_indexable, matrix_field.dynamic_index_stride, |
586 | ret_type)); |
587 | } |
588 | |
589 | Stmt *make_ndarray_access(Expression::FlattenContext *ctx, |
590 | Expr var, |
591 | ExprGroup indices) { |
592 | std::vector<Stmt *> index_stmts; |
593 | for (int i = 0; i < (int)indices.size(); i++) { |
594 | Stmt *ind = flatten_rvalue(indices.exprs[i], ctx); |
595 | index_stmts.push_back(ind); |
596 | } |
597 | auto var_stmt = flatten_lvalue(var, ctx); |
598 | auto expr = var.cast<ExternalTensorExpression>(); |
599 | auto external_ptr_stmt = std::make_unique<ExternalPtrStmt>( |
600 | var_stmt, index_stmts, expr->dt.get_shape(), expr->element_dim); |
601 | if (expr->dim == indices.size()) { |
    // Indexing into a scalar element
603 | external_ptr_stmt->ret_type = expr->dt.ptr_removed().get_element_type(); |
604 | } else { |
605 | // Indexing outer dimensions |
606 | external_ptr_stmt->ret_type = expr->dt.ptr_removed(); |
607 | } |
608 | |
609 | return ctx->push_back(std::move(external_ptr_stmt)); |
610 | } |
611 | |
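// Emits a MatrixPtrStmt addressing one element of a local tensor. Indices
// are linearized row-major via Horner's rule:
//   offset = (...(i0 * s1 + i1) * s2 + i2)...
// When every index is a compile-time constant, the offset is folded into a
// single ConstStmt.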
612 | Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx, |
613 | Stmt *var_stmt, |
614 | const ExprGroup &indices, |
615 | const std::vector<int> &shape, |
616 | const std::string &tb) { |
617 | bool needs_dynamic_index = false; |
618 | for (int i = 0; i < (int)indices.size(); ++i) { |
619 | if (!indices[i].is<ConstExpression>()) { |
620 | needs_dynamic_index = true; |
621 | } |
622 | } |
623 | Stmt *offset_stmt = nullptr; |
624 | if (needs_dynamic_index) { |
625 | offset_stmt = ctx->push_back<ConstStmt>(TypedConstant(0)); |
626 | for (int i = 0; i < (int)indices.size(); ++i) { |
627 | auto index_stmt = flatten_rvalue(indices[i], ctx); |
628 | Stmt *shape_stmt = ctx->push_back<ConstStmt>(TypedConstant(shape[i])); |
629 | Stmt *mul_stmt = ctx->push_back<BinaryOpStmt>(BinaryOpType::mul, |
630 | offset_stmt, shape_stmt); |
631 | offset_stmt = |
632 | ctx->push_back<BinaryOpStmt>(BinaryOpType::add, mul_stmt, index_stmt); |
633 | } |
634 | } else { |
635 | int offset = 0; |
636 | for (int i = 0; i < (int)indices.size(); ++i) { |
637 | offset = |
638 | offset * shape[i] + indices[i].cast<ConstExpression>()->val.val_int(); |
639 | } |
640 | offset_stmt = ctx->push_back<ConstStmt>(TypedConstant(offset)); |
641 | } |
642 | return ctx->push_back<MatrixPtrStmt>(var_stmt, offset_stmt, tb); |
643 | } |
644 | |
645 | Stmt *make_tensor_access(Expression::FlattenContext *ctx, |
646 | Expr var, |
647 | const std::vector<ExprGroup> &indices_group, |
648 | DataType ret_type, |
649 | std::vector<int> shape, |
650 | const std::string &tb) { |
651 | auto var_stmt = flatten_lvalue(var, ctx); |
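  // An rvalue tensor (e.g. an intermediate matrix result) is first
  // materialized into a temporary alloca so it can be indexed through
  // MatrixPtrStmt.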
652 | if (!var->is_lvalue()) { |
653 | auto alloca_stmt = ctx->push_back<AllocaStmt>(var->ret_type); |
654 | ctx->push_back<LocalStoreStmt>(alloca_stmt, var_stmt); |
655 | var_stmt = alloca_stmt; |
656 | } |
657 | if (is_tensor(ret_type)) { |
658 | std::vector<Stmt *> stmts; |
659 | for (auto &indices : indices_group) { |
660 | stmts.push_back( |
661 | make_tensor_access_single_element(ctx, var_stmt, indices, shape, tb)); |
662 | } |
663 | return ctx->push_back<MatrixOfMatrixPtrStmt>(stmts, ret_type); |
664 | } |
665 | return make_tensor_access_single_element(ctx, var_stmt, indices_group[0], |
666 | shape, tb); |
667 | } |
668 | |
669 | void MatrixExpression::type_check(const CompileConfig *config) { |
670 | TI_ASSERT(dt->as<TensorType>()->get_num_elements() == elements.size()); |
671 | |
672 | for (auto &arg : elements) { |
673 | TI_ASSERT_TYPE_CHECKED(arg); |
674 | if (arg->ret_type != dt.get_element_type()) { |
675 | arg = cast(arg, dt.get_element_type()); |
676 | arg->type_check(config); |
677 | } |
678 | } |
679 | ret_type = dt; |
680 | } |
681 | |
682 | void MatrixExpression::flatten(FlattenContext *ctx) { |
683 | TI_ASSERT(this->dt->is<TensorType>()); |
684 | std::vector<Stmt *> values; |
685 | for (auto &elt : elements) { |
686 | values.push_back(flatten_rvalue(elt, ctx)); |
687 | } |
688 | stmt = ctx->push_back<MatrixInitStmt>(values); |
689 | stmt->ret_type = this->dt; |
690 | } |
691 | |
692 | IndexExpression::IndexExpression(const Expr &var, |
693 | const ExprGroup &indices, |
694 | std::string tb) |
695 | : var(var), indices_group({indices}) { |
696 | this->tb = tb; |
697 | } |
698 | |
699 | IndexExpression::IndexExpression(const Expr &var, |
700 | const std::vector<ExprGroup> &indices_group, |
701 | const std::vector<int> &ret_shape, |
702 | std::string tb) |
703 | : var(var), indices_group(indices_group), ret_shape(ret_shape) { |
704 | // IndexExpression with ret_shape is used for matrix slicing, where each entry |
705 | // of ExprGroup is interpreted as a group of indices to return within each |
706 | // axis. For example, mat[0, 3:5] has indices_group={0, [3, 4]}, where [3, 4] |
707 | // means "m"-axis will return a TensorType with size of 2. In this case, we |
708 | // should not expand indices_group due to its special semantics. |
709 | this->tb = tb; |
710 | } |
711 | |
712 | bool IndexExpression::is_field() const { |
713 | return var.is<FieldExpression>(); |
714 | } |
715 | |
716 | bool IndexExpression::is_matrix_field() const { |
717 | return var.is<MatrixFieldExpression>(); |
718 | } |
719 | |
720 | bool IndexExpression::is_ndarray() const { |
721 | return var.is<ExternalTensorExpression>(); |
722 | } |
723 | |
724 | bool IndexExpression::is_tensor() const { |
725 | return var->ret_type->is<TensorType>(); |
726 | } |
727 | |
728 | bool IndexExpression::is_local() const { |
729 | return !is_global(); |
730 | } |
731 | |
732 | bool IndexExpression::is_global() const { |
733 | // Special case: Indexing into TensorType-element of ExternalPtrStmt |
734 | // or GlobalPtrStmt should be treated as global ptrs |
735 | if (var.is<IndexExpression>()) { |
736 | TI_ASSERT(var.cast<IndexExpression>()->is_matrix_field() || |
737 | var.cast<IndexExpression>()->is_ndarray()); |
738 | return true; |
739 | } |
740 | |
  // Only ndarrays and fields come from outside the kernel
742 | return is_field() || is_matrix_field() || is_ndarray(); |
743 | } |
744 | |
745 | static void field_validation(FieldExpression *field_expr, int index_dim) { |
746 | TI_ASSERT(field_expr != nullptr); |
747 | TI_ASSERT(field_expr->snode != nullptr); |
748 | int field_dim = field_expr->snode->num_active_indices; |
749 | |
750 | if (field_dim != index_dim) { |
751 | throw TaichiIndexError( |
        fmt::format("Field with dim {} accessed with indices of dim {}",
753 | field_dim, index_dim)); |
754 | } |
755 | } |
756 | |
757 | void IndexExpression::type_check(const CompileConfig *) { |
758 | // TODO: Change to type-based solution |
759 | // Currently, dimension compatibility check happens in Python |
760 | TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape), |
761 | end(ret_shape), 1, |
762 | std::multiplies<>())); |
763 | int index_dim = indices_group.empty() ? 0 : indices_group[0].size(); |
764 | bool has_slice = !ret_shape.empty(); |
765 | if (has_slice) { |
    TI_ASSERT_INFO(is_tensor(),
                   "Slice or swizzle can only be applied to matrices");
767 | auto element_type = var->ret_type->as<TensorType>()->get_element_type(); |
768 | ret_type = TypeFactory::create_tensor_type(ret_shape, element_type); |
769 | |
770 | } else if (is_field()) { // field |
771 | auto field_expr = var.cast<FieldExpression>(); |
772 | field_validation(field_expr.get(), index_dim); |
773 | ret_type = field_expr->dt->get_compute_type(); |
774 | |
775 | } else if (is_matrix_field()) { |
776 | auto matrix_field_expr = var.cast<MatrixFieldExpression>(); |
777 | |
778 | TI_ASSERT(!matrix_field_expr->fields.empty()); |
779 | auto field_expr = matrix_field_expr->fields[0].cast<FieldExpression>(); |
780 | field_validation(field_expr.get(), index_dim); |
781 | |
782 | ret_type = TypeFactory::create_tensor_type(matrix_field_expr->element_shape, |
783 | matrix_field_expr->fields[0] |
784 | .cast<FieldExpression>() |
785 | ->dt->get_compute_type()); |
786 | } else if (is_ndarray()) { // ndarray |
787 | auto external_tensor_expr = var.cast<ExternalTensorExpression>(); |
788 | int total_dim = external_tensor_expr->dim; |
789 | int element_dim = external_tensor_expr->dt.get_shape().size(); |
790 | if (total_dim != index_dim + element_dim) { |
791 | throw TaichiTypeError( |
          fmt::format("Array with dim {} accessed with indices of dim {}",
793 | total_dim - element_dim, index_dim)); |
794 | } |
795 | |
796 | if (index_dim == total_dim) { |
797 | // Access all the way to a single element |
798 | ret_type = var.cast<ExternalTensorExpression>()->dt.get_element_type(); |
799 | } else { |
800 | // Access to a Tensor |
801 | ret_type = var.cast<ExternalTensorExpression>()->dt; |
802 | } |
803 | } else if (is_tensor()) { // local tensor |
804 | auto shape = var->ret_type->as<TensorType>()->get_shape(); |
805 | if (indices_group[0].size() != shape.size()) { |
      TI_ERROR("Expected {} indices, got {}.", shape.size(),
807 | indices_group[0].size()); |
808 | } |
809 | ret_type = var->ret_type->cast<TensorType>()->get_element_type(); |
810 | } else { |
811 | throw TaichiTypeError( |
812 | "Invalid IndexExpression: the source is not among field, ndarray or " |
        "local tensor");
814 | } |
815 | |
816 | for (auto &indices : indices_group) { |
817 | for (int i = 0; i < indices.exprs.size(); i++) { |
818 | auto &expr = indices.exprs[i]; |
819 | TI_ASSERT_TYPE_CHECKED(expr); |
820 | if (!is_integral(expr->ret_type)) |
821 | throw TaichiTypeError( |
            fmt::format("indices must be integers, but '{}' was "
                        "provided as index {}",
824 | expr->ret_type->to_string(), i)); |
825 | } |
826 | } |
827 | } |
828 | |
829 | void IndexExpression::flatten(FlattenContext *ctx) { |
830 | if (is_field()) { |
831 | stmt = |
832 | make_field_access(ctx, *var.cast<FieldExpression>(), indices_group[0]); |
833 | } else if (is_matrix_field()) { |
834 | stmt = make_matrix_field_access(ctx, *var.cast<MatrixFieldExpression>(), |
835 | indices_group[0], ret_type); |
836 | } else if (is_ndarray()) { |
837 | stmt = make_ndarray_access(ctx, var, indices_group[0]); |
838 | } else if (is_tensor()) { |
839 | stmt = |
840 | make_tensor_access(ctx, var, indices_group, ret_type, |
841 | var->ret_type->cast<TensorType>()->get_shape(), tb); |
842 | } else { |
843 | throw TaichiTypeError( |
844 | "Invalid IndexExpression: the source is not among field, ndarray or " |
        "local tensor");
846 | } |
847 | stmt->tb = tb; |
848 | } |
849 | |
850 | void RangeAssumptionExpression::type_check(const CompileConfig *) { |
851 | TI_ASSERT_TYPE_CHECKED(input); |
852 | TI_ASSERT_TYPE_CHECKED(base); |
853 | if (!input->ret_type->is<PrimitiveType>() || |
854 | !base->ret_type->is<PrimitiveType>() || input->ret_type != base->ret_type) |
855 | throw TaichiTypeError( |
856 | fmt::format("unsupported operand type(s) for " |
                    "'range_assumption': '{}' and '{}'",
858 | input->ret_type->to_string(), base->ret_type->to_string())); |
859 | ret_type = input->ret_type; |
860 | } |
861 | |
862 | void RangeAssumptionExpression::flatten(FlattenContext *ctx) { |
863 | auto input_stmt = flatten_rvalue(input, ctx); |
864 | auto base_stmt = flatten_rvalue(base, ctx); |
865 | ctx->push_back( |
866 | Stmt::make<RangeAssumptionStmt>(input_stmt, base_stmt, low, high)); |
867 | stmt = ctx->back_stmt(); |
868 | } |
869 | |
870 | void LoopUniqueExpression::type_check(const CompileConfig *) { |
871 | TI_ASSERT_TYPE_CHECKED(input); |
872 | if (!input->ret_type->is<PrimitiveType>()) |
873 | throw TaichiTypeError( |
        fmt::format("unsupported operand type(s) for 'loop_unique': '{}'",
875 | input->ret_type->to_string())); |
876 | ret_type = input->ret_type; |
877 | } |
878 | |
879 | void LoopUniqueExpression::flatten(FlattenContext *ctx) { |
880 | auto input_stmt = flatten_rvalue(input, ctx); |
881 | ctx->push_back(Stmt::make<LoopUniqueStmt>(input_stmt, covers)); |
882 | stmt = ctx->back_stmt(); |
883 | } |
884 | |
885 | void IdExpression::flatten(FlattenContext *ctx) { |
886 | stmt = ctx->current_block->lookup_var(id); |
887 | if (!ret_type->is_primitive(PrimitiveTypeID::unknown)) { |
888 | stmt->ret_type = ret_type; |
889 | } |
890 | } |
891 | |
892 | void AtomicOpExpression::type_check(const CompileConfig *config) { |
893 | TI_ASSERT_TYPE_CHECKED(dest); |
894 | TI_ASSERT_TYPE_CHECKED(val); |
895 | auto error = [&]() { |
896 | throw TaichiTypeError(fmt::format( |
        "unsupported operand type(s) for 'atomic_{}': '{}' and '{}'",
898 | atomic_op_type_name(op_type), dest->ret_type->to_string(), |
899 | val->ret_type->to_string())); |
900 | }; |
901 | |
  // Broadcast val to dest if necessary
903 | auto val_dtype = val->ret_type; |
904 | auto dest_dtype = dest->ret_type.ptr_removed(); |
905 | if (dest_dtype->is<PrimitiveType>() && val_dtype->is<TensorType>()) { |
906 | error(); |
907 | } |
908 | |
909 | if (val_dtype->is<PrimitiveType>() && dest_dtype->is<TensorType>()) { |
910 | auto broadcasted_expr = to_broadcast_tensor(val, dest_dtype); |
911 | val = std::move(broadcasted_expr); |
912 | val.type_check(config); |
913 | } |
914 | |
915 | // Validate dtype |
916 | auto dtype = val->ret_type; |
917 | if (dtype->is<TensorType>()) { |
918 | dtype = dtype.get_element_type(); |
919 | } |
920 | |
921 | if (!dtype->is<PrimitiveType>()) { |
922 | error(); |
923 | } |
924 | |
925 | if (is_quant(dest->ret_type)) { |
926 | ret_type = dest->ret_type->get_compute_type(); |
927 | } else if (dest->ret_type->is<PrimitiveType>() || |
928 | dest->ret_type->is<TensorType>()) { |
929 | ret_type = dest->ret_type; |
930 | } else { |
931 | error(); |
932 | } |
933 | } |
934 | |
935 | void AtomicOpExpression::flatten(FlattenContext *ctx) { |
936 | TI_ASSERT( |
937 | dest.is<IdExpression>() || dest.is<IndexExpression>() || |
938 | (dest.is<ArgLoadExpression>() && dest.cast<ArgLoadExpression>()->is_ptr)); |
939 | // replace atomic sub with negative atomic add |
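  //   e.g.  x -= v   ==>   x += (-v)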
940 | if (op_type == AtomicOpType::sub) { |
941 | if (val->ret_type != ret_type) { |
942 | val.set(Expr::make<UnaryOpExpression>(UnaryOpType::cast_value, val, |
943 | ret_type)); |
944 | } |
945 | |
946 | val.set(Expr::make<UnaryOpExpression>(UnaryOpType::neg, val)); |
947 | op_type = AtomicOpType::add; |
948 | } |
949 | // expand rhs |
950 | auto val_stmt = flatten_rvalue(val, ctx); |
951 | auto dest_stmt = flatten_lvalue(dest, ctx); |
952 | stmt = ctx->push_back<AtomicOpStmt>(op_type, dest_stmt, val_stmt); |
953 | stmt->ret_type = stmt->as<AtomicOpStmt>()->dest->ret_type; |
954 | stmt->tb = tb; |
955 | } |
956 | |
957 | SNodeOpExpression::SNodeOpExpression(SNode *snode, |
958 | SNodeOpType op_type, |
959 | const ExprGroup &indices) |
960 | : snode(snode), op_type(op_type), indices(indices) { |
961 | } |
962 | |
963 | SNodeOpExpression::SNodeOpExpression(SNode *snode, |
964 | SNodeOpType op_type, |
965 | const ExprGroup &indices, |
966 | const std::vector<Expr> &values) |
967 | : SNodeOpExpression(snode, op_type, indices) { |
968 | this->values = values; |
969 | } |
970 | |
971 | void SNodeOpExpression::type_check(const CompileConfig *config) { |
972 | if (op_type == SNodeOpType::get_addr) { |
973 | ret_type = PrimitiveType::u64; |
974 | } else { |
975 | ret_type = PrimitiveType::i32; |
976 | } |
977 | if (op_type == SNodeOpType::append) { |
978 | TI_ASSERT(snode->ch.size() == values.size()); |
979 | for (int i = 0; i < values.size(); i++) { |
980 | TI_ASSERT_TYPE_CHECKED(values[i]); |
981 | auto &dst_type = snode->ch[i]->dt; |
982 | auto promoted = promoted_type(dst_type, values[i]->ret_type); |
983 | if (dst_type != promoted) { |
        TI_WARN("Append may lose precision: {} <- {}\n{}",
985 | dst_type->to_string(), values[i]->ret_type->to_string(), tb); |
986 | } |
987 | values[i] = cast(values[i], dst_type); |
988 | values[i]->type_check(config); |
989 | } |
990 | } |
991 | } |
992 | |
993 | void SNodeOpExpression::flatten(FlattenContext *ctx) { |
994 | std::vector<Stmt *> indices_stmt; |
995 | for (int i = 0; i < (int)indices.size(); i++) { |
996 | indices_stmt.push_back(flatten_rvalue(indices[i], ctx)); |
997 | } |
998 | auto is_cell_access = SNodeOpStmt::activation_related(op_type) && |
999 | snode->type != SNodeType::dynamic; |
1000 | auto ptr = |
1001 | ctx->push_back<GlobalPtrStmt>(snode, indices_stmt, true, is_cell_access); |
1002 | ptr->tb = tb; |
1003 | if (op_type == SNodeOpType::is_active) { |
1004 | TI_ERROR_IF(snode->type != SNodeType::pointer && |
1005 | snode->type != SNodeType::hash && |
1006 | snode->type != SNodeType::bitmasked, |
              "ti.is_active only works on pointer, hash or bitmasked nodes.");
1008 | ctx->push_back<SNodeOpStmt>(SNodeOpType::is_active, snode, ptr, nullptr); |
1009 | } else if (op_type == SNodeOpType::length) { |
1010 | ctx->push_back<SNodeOpStmt>(SNodeOpType::length, snode, ptr, nullptr); |
1011 | } else if (op_type == SNodeOpType::get_addr) { |
1012 | ctx->push_back<SNodeOpStmt>(SNodeOpType::get_addr, snode, ptr, nullptr); |
1013 | } else if (op_type == SNodeOpType::append) { |
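    // append lowers into: allocate a slot (its index is written to
    // `alloca`), store each value into the corresponding child of the new
    // cell, then load the returned index as the result of this expression.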
1014 | auto alloca = ctx->push_back<AllocaStmt>(PrimitiveType::i32); |
1015 | alloca->set_tb(tb); |
1016 | auto addr = |
1017 | ctx->push_back<SNodeOpStmt>(SNodeOpType::allocate, snode, ptr, alloca); |
1018 | addr->set_tb(tb); |
1019 | for (int i = 0; i < values.size(); i++) { |
1020 | auto value_stmt = flatten_rvalue(values[i], ctx); |
1021 | auto ch_addr = ctx->push_back<GetChStmt>(addr, snode, i); |
1022 | ch_addr->set_tb(tb); |
1023 | ctx->push_back<GlobalStoreStmt>(ch_addr, value_stmt)->set_tb(tb); |
1024 | } |
1025 | ctx->push_back<LocalLoadStmt>(alloca)->set_tb(tb); |
1026 | TI_ERROR_IF(snode->type != SNodeType::dynamic, |
                "ti.append only works on dynamic nodes.");
1028 | } |
1029 | stmt = ctx->back_stmt(); |
1030 | } |
1031 | |
1032 | TextureOpExpression::TextureOpExpression(TextureOpType op, |
1033 | Expr texture_ptr, |
1034 | const ExprGroup &args) |
1035 | : op(op), texture_ptr(texture_ptr), args(args) { |
1036 | } |
1037 | |
1038 | void TextureOpExpression::type_check(const CompileConfig *config) { |
1039 | TI_ASSERT(texture_ptr.is<TexturePtrExpression>()); |
1040 | auto ptr = texture_ptr.cast<TexturePtrExpression>(); |
1041 | if (op == TextureOpType::kSampleLod) { |
1042 | // UV, Lod |
1043 | TI_ASSERT_INFO(args.size() == ptr->num_dims + 1, |
1044 | "Invalid number of args for sample_lod Texture op with a " |
1045 | "{}-dimension texture" , |
1046 | ptr->num_dims); |
1047 | for (int i = 0; i < ptr->num_dims; i++) { |
1048 | TI_ASSERT_TYPE_CHECKED(args[i]); |
1049 | if (args[i].get_ret_type() != PrimitiveType::f32) { |
1050 | throw TaichiTypeError( |
1051 | fmt::format("Invalid type for texture sample_lod: '{}', all " |
                        "arguments must be f32",
1053 | args[i].get_ret_type()->to_string())); |
1054 | } |
1055 | } |
1056 | } else if (op == TextureOpType::kFetchTexel) { |
1057 | // index, int LOD |
1058 | TI_ASSERT_INFO(args.size() == ptr->num_dims + 1, |
1059 | "Invalid number of args for fetch_texel Texture op with a " |
1060 | "{}-dimension texture" , |
1061 | ptr->num_dims); |
1062 | for (int i = 0; i < ptr->num_dims; i++) { |
1063 | TI_ASSERT_TYPE_CHECKED(args[i]); |
1064 | if (args[i].get_ret_type() != PrimitiveType::i32) { |
1065 | throw TaichiTypeError( |
1066 | fmt::format("Invalid type for texture fetch_texel: '{}', all " |
                        "arguments must be i32",
1068 | args[i].get_ret_type()->to_string())); |
1069 | } |
1070 | } |
1071 | } else if (op == TextureOpType::kLoad) { |
1072 | // index |
1073 | TI_ASSERT_INFO(args.size() == ptr->num_dims, |
1074 | "Invalid number of args for load Texture op with a " |
1075 | "{}-dimension texture" , |
1076 | ptr->num_dims); |
1077 | for (int i = 0; i < ptr->num_dims; i++) { |
1078 | TI_ASSERT_TYPE_CHECKED(args[i]); |
1079 | if (args[i].get_ret_type() != PrimitiveType::i32) { |
1080 | throw TaichiTypeError( |
1081 | fmt::format("Invalid type for texture load: '{}', all " |
                        "arguments must be i32",
1083 | args[i].get_ret_type()->to_string())); |
1084 | } |
1085 | } |
1086 | } else if (op == TextureOpType::kStore) { |
1087 | // index, value |
1088 | TI_ASSERT_INFO(args.size() == ptr->num_dims + 4, |
1089 | "Invalid number of args for store Texture op with a " |
1090 | "{}-dimension texture" , |
1091 | ptr->num_dims); |
1092 | for (int i = 0; i < ptr->num_dims; i++) { |
1093 | TI_ASSERT_TYPE_CHECKED(args[i]); |
1094 | if (args[i].get_ret_type() != PrimitiveType::i32) { |
1095 | throw TaichiTypeError( |
            fmt::format("Invalid type for texture store: '{}', index "
                        "arguments must be i32",
1098 | args[i].get_ret_type()->to_string())); |
1099 | } |
1100 | } |
1101 | for (int i = ptr->num_dims; i < ptr->num_dims + 4; i++) { |
1102 | TI_ASSERT_TYPE_CHECKED(args[i]); |
1103 | if (args[i].get_ret_type() != PrimitiveType::f32) { |
1104 | throw TaichiTypeError( |
            fmt::format("Invalid type for texture store: '{}', value "
                        "arguments must be f32",
1107 | args[i].get_ret_type()->to_string())); |
1108 | } |
1109 | } |
1110 | } else { |
    TI_ERROR("Invalid TextureOpType");
1112 | } |
1113 | ret_type = |
1114 | TypeFactory::get_instance().get_pointer_type(PrimitiveType::f32, |
1115 | /*is_bit_pointer=*/false); |
1116 | } |
1117 | |
1118 | void TextureOpExpression::flatten(FlattenContext *ctx) { |
1119 | auto texture_ptr_stmt = flatten_rvalue(texture_ptr, ctx); |
1120 | std::vector<Stmt *> arg_stmts; |
1121 | for (Expr &arg : args.exprs) { |
1122 | arg_stmts.push_back(flatten_rvalue(arg, ctx)); |
1123 | } |
1124 | ctx->push_back<TextureOpStmt>(op, texture_ptr_stmt, arg_stmts); |
1125 | stmt = ctx->back_stmt(); |
1126 | } |
1127 | |
1128 | void ConstExpression::type_check(const CompileConfig *) { |
1129 | TI_ASSERT_INFO( |
1130 | val.dt->is<PrimitiveType>() && val.dt != PrimitiveType::unknown, |
      "Invalid dt [{}] for ConstExpression", val.dt->to_string());
1132 | ret_type = val.dt; |
1133 | } |
1134 | |
1135 | void ConstExpression::flatten(FlattenContext *ctx) { |
1136 | ctx->push_back(Stmt::make<ConstStmt>(val)); |
1137 | stmt = ctx->back_stmt(); |
1138 | } |
1139 | |
1140 | void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) { |
1141 | TI_ASSERT_INFO( |
1142 | ptr.is<ExternalTensorExpression>() || ptr.is<TexturePtrExpression>(), |
      "Invalid ptr [{}] for ExternalTensorShapeAlongAxisExpression",
1144 | ExpressionHumanFriendlyPrinter::expr_to_string(ptr)); |
1145 | ret_type = PrimitiveType::i32; |
1146 | } |
1147 | |
1148 | void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { |
1149 | auto temp = ptr.cast<ExternalTensorExpression>(); |
1150 | TI_ASSERT(0 <= axis && axis < temp->dim); |
1151 | ctx->push_back<ExternalTensorShapeAlongAxisStmt>(axis, temp->arg_id); |
1152 | stmt = ctx->back_stmt(); |
1153 | } |
1154 | |
1155 | void GetElementExpression::type_check(const CompileConfig *config) { |
1156 | TI_ASSERT_TYPE_CHECKED(src); |
1157 | |
1158 | ret_type = src->ret_type->as<StructType>()->get_element_type(index); |
1159 | } |
1160 | |
1161 | void GetElementExpression::flatten(FlattenContext *ctx) { |
1162 | ctx->push_back<GetElementStmt>(flatten_rvalue(src, ctx), index); |
1163 | stmt = ctx->back_stmt(); |
}

// Mesh related.

1167 | void MeshPatchIndexExpression::flatten(FlattenContext *ctx) { |
1168 | auto pid_stmt = std::make_unique<MeshPatchIndexStmt>(); |
1169 | ctx->push_back(std::move(pid_stmt)); |
1170 | stmt = ctx->back_stmt(); |
1171 | } |
1172 | |
1173 | void MeshPatchIndexExpression::type_check(const CompileConfig *) { |
1174 | ret_type = PrimitiveType::i32; |
1175 | } |
1176 | |
1177 | void MeshRelationAccessExpression::type_check(const CompileConfig *) { |
1178 | ret_type = PrimitiveType::i32; |
1179 | } |
1180 | |
1181 | void MeshRelationAccessExpression::flatten(FlattenContext *ctx) { |
1182 | auto mesh_idx_stmt = flatten_rvalue(mesh_idx, ctx); |
1183 | if (neighbor_idx) { |
1184 | auto neighbor_idx_stmt = flatten_rvalue(neighbor_idx, ctx); |
1185 | ctx->push_back<MeshRelationAccessStmt>(mesh, mesh_idx_stmt, to_type, |
1186 | neighbor_idx_stmt); |
1187 | } else { |
1188 | ctx->push_back<MeshRelationAccessStmt>(mesh, mesh_idx_stmt, to_type); |
1189 | } |
1190 | stmt = ctx->back_stmt(); |
1191 | } |
1192 | |
1193 | MeshIndexConversionExpression::MeshIndexConversionExpression( |
1194 | mesh::Mesh *mesh, |
1195 | mesh::MeshElementType idx_type, |
1196 | const Expr idx, |
1197 | mesh::ConvType conv_type) |
1198 | : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { |
1199 | } |
1200 | |
1201 | void MeshIndexConversionExpression::type_check(const CompileConfig *) { |
1202 | ret_type = PrimitiveType::i32; |
1203 | } |
1204 | |
1205 | void MeshIndexConversionExpression::flatten(FlattenContext *ctx) { |
1206 | auto idx_stmt = flatten_rvalue(idx, ctx); |
1207 | ctx->push_back<MeshIndexConversionStmt>(mesh, idx_type, idx_stmt, conv_type); |
1208 | stmt = ctx->back_stmt(); |
1209 | } |
1210 | |
1211 | void ReferenceExpression::type_check(const CompileConfig *) { |
1212 | ret_type = var->ret_type; |
1213 | } |
1214 | |
1215 | void ReferenceExpression::flatten(FlattenContext *ctx) { |
1216 | auto var_stmt = flatten_lvalue(var, ctx); |
1217 | ctx->push_back<ReferenceStmt>(var_stmt); |
1218 | stmt = ctx->back_stmt(); |
1219 | } |
1220 | |
1221 | Block *ASTBuilder::current_block() { |
1222 | if (stack_.empty()) |
1223 | return nullptr; |
1224 | else |
1225 | return stack_.back(); |
1226 | } |
1227 | |
1228 | Stmt *ASTBuilder::get_last_stmt() { |
1229 | TI_ASSERT(!stack_.empty()); |
1230 | return stack_.back()->back(); |
1231 | } |
1232 | |
1233 | void ASTBuilder::insert(std::unique_ptr<Stmt> &&stmt, int location) { |
1234 | TI_ASSERT(!stack_.empty()); |
1235 | stack_.back()->insert(std::move(stmt), location); |
1236 | } |
1237 | |
1238 | void ASTBuilder::stop_gradient(SNode *snode) { |
1239 | TI_ASSERT(!stack_.empty()); |
1240 | stack_.back()->stop_gradients.push_back(snode); |
1241 | } |
1242 | |
1243 | void ASTBuilder::insert_assignment(Expr &lhs, |
1244 | const Expr &rhs, |
1245 | const std::string &tb) { |
1246 | // Inside a kernel or a function |
1247 | // Create an assignment in the IR |
1248 | if (lhs.expr == nullptr) { |
1249 | lhs.set(rhs); |
1250 | } else if (lhs.expr->is_lvalue()) { |
1251 | auto stmt = std::make_unique<FrontendAssignStmt>(lhs, rhs); |
1252 | stmt->tb = tb; |
1253 | this->insert(std::move(stmt)); |
1254 | |
1255 | } else { |
    TI_ERROR("Cannot assign to non-lvalue: {}",
1257 | ExpressionHumanFriendlyPrinter::expr_to_string(lhs)); |
1258 | } |
1259 | } |
1260 | |
1261 | Expr ASTBuilder::make_var(const Expr &x, std::string tb) { |
1262 | auto var = this->expr_alloca(); |
1263 | this->insert_assignment(var, x, tb); |
1264 | return var; |
1265 | } |
1266 | |
1267 | Expr ASTBuilder::make_id_expr(const std::string &name) { |
1268 | return Expr::make<IdExpression>(get_next_id(name)); |
1269 | } |
1270 | |
1271 | void ASTBuilder::insert_for(const Expr &s, |
1272 | const Expr &e, |
1273 | const std::function<void(Expr)> &func) { |
1274 | auto i = Expr(std::make_shared<IdExpression>(get_next_id())); |
1275 | auto stmt_unique = std::make_unique<FrontendForStmt>(i, s, e, this->arch_, |
1276 | for_loop_dec_.config); |
1277 | for_loop_dec_.reset(); |
1278 | auto stmt = stmt_unique.get(); |
1279 | this->insert(std::move(stmt_unique)); |
1280 | this->create_scope(stmt->body); |
1281 | func(i); |
1282 | this->pop_scope(); |
1283 | } |
1284 | |
1285 | Expr ASTBuilder::insert_thread_idx_expr() { |
1286 | auto loop = stack_.size() ? stack_.back()->parent_stmt : nullptr; |
1287 | TI_ERROR_IF( |
1288 | arch_ != Arch::cuda && !arch_is_cpu(arch_) && arch_ != Arch::amdgpu, |
      "ti.thread_idx() is only available in cuda, cpu or amdgpu contexts.");
1290 | if (loop != nullptr) { |
1291 | auto i = stack_.size() - 1; |
1292 | while (!(loop->is<FrontendForStmt>())) { |
1293 | loop = i > 0 ? stack_[--i]->parent_stmt : nullptr; |
1294 | if (loop == nullptr) |
1295 | break; |
1296 | } |
1297 | } |
1298 | TI_ERROR_IF(!(loop && loop->is<FrontendForStmt>()), |
              "ti.thread_idx() is only valid within loops.");
1300 | return Expr::make<InternalFuncCallExpression>( |
      "linear_thread_idx", std::vector<Expr>{}, /*with_runtime_context=*/true);
1302 | } |
1303 | |
1304 | Expr ASTBuilder::insert_patch_idx_expr() { |
1305 | auto loop = stack_.size() ? stack_.back()->parent_stmt : nullptr; |
1306 | if (loop != nullptr) { |
1307 | auto i = stack_.size() - 1; |
1308 | while (!(loop->is<FrontendForStmt>())) { |
1309 | loop = i > 0 ? stack_[--i]->parent_stmt : nullptr; |
1310 | if (loop == nullptr) |
1311 | break; |
1312 | } |
1313 | } |
1314 | TI_ERROR_IF(!(loop && loop->is<FrontendForStmt>() && |
1315 | loop->as<FrontendForStmt>()->mesh), |
              "ti.mesh_patch_idx() is only valid within mesh-for loops.");
1317 | return Expr::make<MeshPatchIndexExpression>(); |
1318 | } |
1319 | |
1320 | void ASTBuilder::create_kernel_exprgroup_return(const ExprGroup &group) { |
1321 | auto expanded_exprs = this->expand_exprs(group.exprs); |
1322 | ExprGroup expanded_expr_group; |
1323 | expanded_expr_group.exprs = std::move(expanded_exprs); |
1324 | this->insert(Stmt::make<FrontendReturnStmt>(expanded_expr_group)); |
1325 | } |
1326 | |
1327 | void ASTBuilder::create_print( |
1328 | std::vector<std::variant<Expr, std::string>> contents) { |
1329 | this->insert(std::make_unique<FrontendPrintStmt>(contents)); |
1330 | } |
1331 | |
1332 | void ASTBuilder::begin_func(const std::string &funcid) { |
1333 | auto stmt_unique = std::make_unique<FrontendFuncDefStmt>(funcid); |
1334 | auto stmt = stmt_unique.get(); |
1335 | this->insert(std::move(stmt_unique)); |
1336 | this->create_scope(stmt->body); |
1337 | } |
1338 | |
1339 | void ASTBuilder::end_func(const std::string &funcid) { |
1340 | this->pop_scope(); |
1341 | } |
1342 | |
1343 | void ASTBuilder::begin_frontend_if(const Expr &cond) { |
1344 | auto stmt_tmp = std::make_unique<FrontendIfStmt>(cond); |
1345 | this->insert(std::move(stmt_tmp)); |
1346 | } |
1347 | |
1348 | void ASTBuilder::begin_frontend_if_true() { |
1349 | auto if_stmt = this->get_last_stmt()->as<FrontendIfStmt>(); |
1350 | this->create_scope(if_stmt->true_statements); |
1351 | } |
1352 | |
1353 | void ASTBuilder::begin_frontend_if_false() { |
1354 | auto if_stmt = this->get_last_stmt()->as<FrontendIfStmt>(); |
1355 | this->create_scope(if_stmt->false_statements); |
1356 | } |
1357 | |
1358 | void ASTBuilder::insert_external_func_call(std::size_t func_addr, |
1359 | std::string source, |
1360 | std::string filename, |
1361 | std::string funcname, |
1362 | const ExprGroup &args, |
1363 | const ExprGroup &outputs) { |
1364 | auto stmt = Stmt::make<FrontendExternalFuncStmt>( |
1365 | (void *)func_addr, source, filename, funcname, args.exprs, outputs.exprs); |
1366 | this->insert(std::move(stmt)); |
1367 | } |
1368 | |
1369 | Expr ASTBuilder::expr_alloca() { |
1370 | auto var = Expr(std::make_shared<IdExpression>(get_next_id())); |
1371 | this->insert(std::make_unique<FrontendAllocaStmt>( |
1372 | std::static_pointer_cast<IdExpression>(var.expr)->id, |
1373 | PrimitiveType::unknown)); |
1374 | return var; |
1375 | } |
1376 | |
1377 | std::optional<Expr> ASTBuilder::insert_func_call(Function *func, |
1378 | const ExprGroup &args) { |
1379 | ExprGroup expanded_args; |
1380 | expanded_args.exprs = this->expand_exprs(args.exprs); |
1381 | if (func->ret_type) { |
1382 | auto var = Expr(std::make_shared<IdExpression>(get_next_id())); |
1383 | this->insert(std::make_unique<FrontendFuncCallStmt>( |
1384 | func, expanded_args, |
1385 | std::static_pointer_cast<IdExpression>(var.expr)->id)); |
1386 | var.expr->ret_type = func->ret_type; |
1387 | return var; |
1388 | } else { |
1389 | this->insert(std::make_unique<FrontendFuncCallStmt>(func, expanded_args)); |
1390 | return std::nullopt; |
1391 | } |
1392 | } |
1393 | |
1394 | Expr ASTBuilder::make_matrix_expr(const std::vector<int> &shape, |
1395 | const DataType &dt, |
1396 | const std::vector<Expr> &elements) { |
  /*
    Since we have both "shape" and "element_type" in MatrixExpression,
    we flatten all the elements and disallow recursive TensorTypes in the
    element Exprs.
  */
1402 | TI_ASSERT(dt->is<PrimitiveType>()); |
1403 | auto expanded_elements = this->expand_exprs(elements); |
1404 | auto mat = |
1405 | Expr(std::make_shared<MatrixExpression>(expanded_elements, shape, dt)); |
1406 | return mat; |
1407 | } |
1408 | |
1409 | Expr ASTBuilder::expr_alloca_shared_array(const std::vector<int> &shape, |
1410 | const DataType &element_type) { |
1411 | auto var = Expr(std::make_shared<IdExpression>(get_next_id())); |
1412 | this->insert(std::make_unique<FrontendAllocaStmt>( |
1413 | std::static_pointer_cast<IdExpression>(var.expr)->id, shape, element_type, |
1414 | true)); |
1415 | var->ret_type = this->get_last_stmt()->ret_type; |
1416 | return var; |
1417 | } |
1418 | |
1419 | void ASTBuilder::expr_assign(const Expr &lhs, const Expr &rhs, std::string tb) { |
1420 | TI_ASSERT(lhs->is_lvalue()); |
1421 | auto stmt = std::make_unique<FrontendAssignStmt>(lhs, rhs); |
1422 | stmt->set_tb(tb); |
1423 | this->insert(std::move(stmt)); |
1424 | } |
1425 | |
1426 | Expr ASTBuilder::expr_subscript(const Expr &expr, |
1427 | const ExprGroup &indices, |
1428 | std::string tb) { |
1429 | TI_ASSERT(expr.is<FieldExpression>() || expr.is<MatrixFieldExpression>() || |
1430 | expr.is<ExternalTensorExpression>() || |
1431 | is_tensor(expr.expr->ret_type)); |
1432 | |
1433 | // IndexExpression without ret_shape is used for matrix indexing, |
1434 | // where each entry of ExprGroup is interpreted as indexing into a specific |
1435 | // axis. For example, mat[3, 4] has indices_group={[3, 4]}, where [3, 4] |
1436 | // corresponds to "n"-axis and "m"-axis of the matrix. Therefore we expand |
1437 | // indices_group={[3, 4]} into {3, 4} to avoid TensorType in indices. |
1438 | std::vector<Expr> expanded_indices = this->expand_exprs(indices.exprs); |
1439 | auto expanded_expr_group = ExprGroup(); |
1440 | expanded_expr_group.exprs = expanded_indices; |
1441 | |
1442 | return Expr::make<IndexExpression>(expr, expanded_expr_group, tb); |
1443 | } |
1444 | |
1445 | void ASTBuilder::create_assert_stmt(const Expr &cond, |
1446 | const std::string &msg, |
1447 | const std::vector<Expr> &args) { |
1448 | auto stmt_unique = std::make_unique<FrontendAssertStmt>(cond, msg, args); |
1449 | this->insert(std::move(stmt_unique)); |
1450 | } |
1451 | |
void ASTBuilder::begin_frontend_range_for(const Expr &i,
                                          const Expr &s,
                                          const Expr &e) {
  auto stmt_unique =
      std::make_unique<FrontendForStmt>(i, s, e, arch_, for_loop_dec_.config);
  auto stmt = stmt_unique.get();
  this->insert(std::move(stmt_unique));
  this->create_scope(stmt->body,
                     for_loop_dec_.config.strictly_serialized ? While : For);
  for_loop_dec_.reset();
}

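// Struct fors visit the active elements of an SNode in an unspecified,
// generally parallel order, so ti.loop_config(serialize=True) cannot be
// honored here and only triggers the warning below.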
void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars,
                                                    SNode *snode) {
  TI_WARN_IF(
      for_loop_dec_.config.strictly_serialized,
      "ti.loop_config(serialize=True) has no effect on a struct for. "
      "The execution order is not guaranteed.");
  auto stmt_unique = std::make_unique<FrontendForStmt>(loop_vars, snode, arch_,
                                                       for_loop_dec_.config);
  for_loop_dec_.reset();
  auto stmt = stmt_unique.get();
  this->insert(std::move(stmt_unique));
  this->create_scope(stmt->body, For);
}

void ASTBuilder::begin_frontend_struct_for_on_external_tensor(
    const ExprGroup &loop_vars,
    const Expr &external_tensor) {
  TI_WARN_IF(
      for_loop_dec_.config.strictly_serialized,
      "ti.loop_config(serialize=True) has no effect on a struct for. "
      "The execution order is not guaranteed.");
  auto stmt_unique = std::make_unique<FrontendForStmt>(
      loop_vars, external_tensor, arch_, for_loop_dec_.config);
  for_loop_dec_.reset();
  auto stmt = stmt_unique.get();
  this->insert(std::move(stmt_unique));
  this->create_scope(stmt->body, For);
}

void ASTBuilder::begin_frontend_mesh_for(
    const Expr &i,
    const mesh::MeshPtr &mesh_ptr,
    const mesh::MeshElementType &element_type) {
  TI_WARN_IF(
      for_loop_dec_.config.strictly_serialized,
      "ti.loop_config(serialize=True) has no effect on a mesh for. "
      "The execution order is not guaranteed.");
  auto stmt_unique = std::make_unique<FrontendForStmt>(
      ExprGroup(i), mesh_ptr, element_type, arch_, for_loop_dec_.config);
  for_loop_dec_.reset();
  auto stmt = stmt_unique.get();
  this->insert(std::move(stmt_unique));
  this->create_scope(stmt->body, For);
}

void ASTBuilder::begin_frontend_while(const Expr &cond) {
  auto stmt_unique = std::make_unique<FrontendWhileStmt>(cond);
  auto stmt = stmt_unique.get();
  this->insert(std::move(stmt_unique));
  this->create_scope(stmt->body, While);
}

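// `break` is only legal in nested loops: the outermost loop of a kernel is
// parallelized, so breaking out of it would be ill-defined.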
void ASTBuilder::insert_break_stmt() {
  if (loop_state_stack_.back() == Outermost) {
    throw TaichiSyntaxError("Cannot break in the outermost loop");
  }
  this->insert(Stmt::make<FrontendBreakStmt>());
}

void ASTBuilder::insert_continue_stmt() {
  this->insert(Stmt::make<FrontendContinueStmt>());
}

void ASTBuilder::insert_expr_stmt(const Expr &val) {
  this->insert(Stmt::make<FrontendExprStmt>(val));
}

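// SNode activate/deactivate expect scalar indices, so TensorType indices
// are scalarized through expand_exprs first.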
void ASTBuilder::insert_snode_activate(SNode *snode,
                                       const ExprGroup &expr_group) {
  ExprGroup expanded_group;
  expanded_group.exprs = this->expand_exprs(expr_group.exprs);
  this->insert(Stmt::make<FrontendSNodeOpStmt>(SNodeOpType::activate, snode,
                                               expanded_group));
}

void ASTBuilder::insert_snode_deactivate(SNode *snode,
                                         const ExprGroup &expr_group) {
  ExprGroup expanded_group;
  expanded_group.exprs = this->expand_exprs(expr_group.exprs);
  this->insert(Stmt::make<FrontendSNodeOpStmt>(SNodeOpType::deactivate, snode,
                                               expanded_group));
}

Expr ASTBuilder::snode_append(SNode *snode,
                              const ExprGroup &indices,
                              const std::vector<Expr> &vals) {
  ExprGroup expanded_exprs;
  expanded_exprs.exprs = this->expand_exprs(indices.exprs);
  std::vector<Expr> expanded_vals = this->expand_exprs(vals);
  return Expr::make<SNodeOpExpression>(snode, SNodeOpType::append,
                                       expanded_exprs, expanded_vals);
}

Expr ASTBuilder::snode_is_active(SNode *snode, const ExprGroup &indices) {
  ExprGroup expanded_exprs;
  expanded_exprs.exprs = this->expand_exprs(indices.exprs);
  return Expr::make<SNodeOpExpression>(snode, SNodeOpType::is_active,
                                       expanded_exprs);
}

Expr ASTBuilder::snode_length(SNode *snode, const ExprGroup &indices) {
  ExprGroup expanded_exprs;
  expanded_exprs.exprs = this->expand_exprs(indices.exprs);
  return Expr::make<SNodeOpExpression>(snode, SNodeOpType::length,
                                       expanded_exprs);
}

Expr ASTBuilder::snode_get_addr(SNode *snode, const ExprGroup &indices) {
  ExprGroup expanded_exprs;
  expanded_exprs.exprs = this->expand_exprs(indices.exprs);
  return Expr::make<SNodeOpExpression>(snode, SNodeOpType::get_addr,
                                       expanded_exprs);
}

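// Scalarizes every TensorType expr in `exprs` into one IndexExpression per
// element (only rank-1 and rank-2 tensors are supported); non-tensor exprs
// pass through unchanged. Every input must already be type-checked.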
std::vector<Expr> ASTBuilder::expand_exprs(const std::vector<Expr> &exprs) {
  if (exprs.empty()) {
    return exprs;
  }

  std::vector<Expr> expanded_exprs;
  for (const auto &expr : exprs) {
    TI_ASSERT_TYPE_CHECKED(expr);
    if (!expr->ret_type->is<TensorType>()) {
      expanded_exprs.push_back(expr);
    } else {
      // Expand TensorType expr
      /*
        Before:
          TensorType<4 x i32> index = Expr;

        After:
          TensorType<4 x i32>* id_expr = FrontendAllocaStmt(TensorType<4 x i32>)
          i32 ind0 = IndexExpression(id_expr, 0)
          i32 ind1 = IndexExpression(id_expr, 1)
          i32 ind2 = IndexExpression(id_expr, 2)
          i32 ind3 = IndexExpression(id_expr, 3)

          return {ind0, ind1, ind2, ind3}
      */
      auto tensor_type = expr->ret_type->cast<TensorType>();

      Expr id_expr;
      if (expr.is<IdExpression>()) {
        id_expr = expr;
      } else {
        id_expr = make_var(expr, expr->tb);
      }
      auto shape = tensor_type->get_shape();
      if (shape.size() == 1) {
        for (int i = 0; i < shape[0]; i++) {
          auto ind = Expr(std::make_shared<IndexExpression>(
              id_expr, ExprGroup(Expr(i)), expr->tb));
          ind.expr->ret_type = tensor_type->get_element_type();
          expanded_exprs.push_back(ind);
        }
      } else {
        TI_ASSERT(shape.size() == 2);
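        // Row-major expansion: (0, 0), (0, 1), ..., (1, 0), (1, 1), ...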
        for (int i = 0; i < shape[0]; i++) {
          for (int j = 0; j < shape[1]; j++) {
            auto ind = Expr(std::make_shared<IndexExpression>(
                id_expr, ExprGroup(Expr(i), Expr(j)), expr->tb));
            ind.expr->ret_type = tensor_type->get_element_type();
            expanded_exprs.push_back(ind);
          }
        }
      }
    }
  }

  return expanded_exprs;
}

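// Builds a mesh index conversion expression. An IdExpression whose type is
// still unknown cannot be expanded yet and is passed through as-is;
// otherwise the index is scalarized and must contain exactly one element.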
Expr ASTBuilder::mesh_index_conversion(mesh::MeshPtr mesh_ptr,
                                       mesh::MeshElementType idx_type,
                                       const Expr &idx,
                                       mesh::ConvType &conv_type) {
  Expr expanded_idx;
  if (idx.is<IdExpression>() && idx.get_ret_type() == PrimitiveType::unknown) {
    expanded_idx = idx;
  } else {
    if (idx.expr->ret_type->is<TensorType>()) {
      TI_ASSERT(idx.expr->ret_type->cast<TensorType>()->get_num_elements() ==
                1);
    }
    expanded_idx = this->expand_exprs({idx})[0];
  }

  return Expr::make<MeshIndexConversionExpression>(mesh_ptr.ptr.get(), idx_type,
                                                   expanded_idx, conv_type);
}

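// Pushes a new Block and records the loop nesting state: the first For
// scope opened at the top level is Outermost (a parallelized loop), any
// other loop scope is Inner, and non-loop scopes inherit the enclosing
// state. insert_break_stmt relies on this to reject `break` in Outermost.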
void ASTBuilder::create_scope(std::unique_ptr<Block> &list, LoopType tp) {
  TI_ASSERT(list == nullptr);
  LoopState prev = loop_state_stack_.back();
  if (tp == NotLoop) {
    loop_state_stack_.push_back(prev);
  } else if (tp == For && stack_.size() == 1) {
    loop_state_stack_.push_back(Outermost);
  } else {
    loop_state_stack_.push_back(Inner);
  }
  list = std::make_unique<Block>();
  if (!stack_.empty()) {
    list->parent_stmt = get_last_stmt();
  }
  stack_.push_back(list.get());
}

void ASTBuilder::pop_scope() {
  stack_.pop_back();
  loop_state_stack_.pop_back();
}

Expr ASTBuilder::make_texture_op_expr(const TextureOpType &op,
                                      const Expr &texture_ptr,
                                      const ExprGroup &args) {
  ExprGroup expanded_args;
  expanded_args.exprs = this->expand_exprs(args.exprs);
  return Expr::make<TextureOpExpression>(op, texture_ptr, expanded_args);
}

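// Lowers an lvalue expression into the flatten context and returns the
// pointer-like statement it produced; no load is inserted.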
Stmt *flatten_lvalue(Expr expr, Expression::FlattenContext *ctx) {
  expr->flatten(ctx);
  return expr->get_flattened_stmt();
}

Stmt *flatten_global_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) {
  ctx->push_back(std::make_unique<GlobalLoadStmt>(ptr_stmt));
  return ctx->back_stmt();
}

Stmt *flatten_local_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) {
  auto local_load = ctx->push_back<LocalLoadStmt>(ptr_stmt);
  local_load->ret_type = local_load->src->ret_type.ptr_removed();
  return local_load;
}

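// Lowers an rvalue: flattens the expression, then inserts a LocalLoadStmt
// for local allocas and local index expressions, or a GlobalLoadStmt for
// global pointers (fields, external tensors, pointer arguments).
// Expressions that already produce a value are returned unchanged.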
Stmt *flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) {
  ptr->flatten(ctx);
  Stmt *ptr_stmt = ptr->get_flattened_stmt();
  if (ptr.is<IdExpression>()) {
    if (ptr_stmt->is<AllocaStmt>()) {
      return flatten_local_load(ptr_stmt, ctx);
    }
  } else if (ptr.is<IndexExpression>()) {
    auto ix = ptr.cast<IndexExpression>();
    if (ix->is_local()) {
      return flatten_local_load(ptr_stmt, ctx);
    } else {
      return flatten_global_load(ptr_stmt, ctx);
    }
  } else if (ptr.is<ArgLoadExpression>() &&
             ptr.cast<ArgLoadExpression>()->is_ptr) {
    return flatten_global_load(ptr_stmt, ctx);
  }

  return ptr_stmt;
}

}  // namespace taichi::lang