1 | #include "taichi/ir/ir.h" |
2 | #include "taichi/ir/statements.h" |
3 | #include "taichi/ir/transforms.h" |
4 | #include "taichi/ir/analysis.h" |
5 | #include "taichi/ir/visitors.h" |
6 | #include "taichi/ir/frontend_ir.h" |
7 | #include "taichi/system/profiler.h" |
8 | |
9 | #include <unordered_set> |
10 | |
11 | namespace taichi::lang { |
12 | |
13 | namespace { |
14 | |
15 | using FlattenContext = Expression::FlattenContext; |
16 | |
17 | template <typename Vec> |
18 | std::vector<typename Vec::value_type::pointer> make_raw_pointer_list( |
19 | const Vec &unique_pointers) { |
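  // Snapshot raw pointers before traversal: the visitors below replace
  // frontend statements in place, which would otherwise invalidate iteration
  // over the owning vector of unique_ptrs.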
20 | std::vector<typename Vec::value_type::pointer> raw_pointers; |
21 | for (auto &ptr : unique_pointers) |
22 | raw_pointers.push_back(ptr.get()); |
23 | return raw_pointers; |
24 | } |
25 | |
26 | } // namespace |
27 | |
// Lower the frontend Expr tree into flat sequences of unary/binary statements.
// Goal: eliminate Expression, Identifier, and mutable local variables; make
// the AST SSA.
31 | class LowerAST : public IRVisitor { |
32 | private: |
33 | Stmt *capturing_loop_; |
34 | std::unordered_set<Stmt *> detected_fors_with_break_; |
35 | Block *current_block_; |
36 | int current_block_depth_; |
37 | |
38 | FlattenContext make_flatten_ctx() { |
39 | FlattenContext fctx; |
40 | fctx.current_block = this->current_block_; |
41 | return fctx; |
42 | } |
43 | |
44 | public: |
45 | explicit LowerAST(const std::unordered_set<Stmt *> &_detected_fors_with_break) |
46 | : detected_fors_with_break_(_detected_fors_with_break), |
47 | current_block_(nullptr), |
48 | current_block_depth_(0) { |
49 | // TODO: change this to false |
50 | allow_undefined_visitor = true; |
51 | capturing_loop_ = nullptr; |
52 | } |
53 | |
54 | void visit(Block *stmt_list) override { |
55 | auto backup_block = this->current_block_; |
56 | this->current_block_ = stmt_list; |
57 | auto stmts = make_raw_pointer_list(stmt_list->statements); |
58 | current_block_depth_++; |
59 | for (auto &stmt : stmts) { |
60 | stmt->accept(this); |
61 | } |
62 | current_block_depth_--; |
63 | this->current_block_ = backup_block; |
64 | } |
65 | |
66 | void visit(FrontendAllocaStmt *stmt) override { |
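    // Replace the frontend alloca with a backend AllocaStmt (tensor-shaped if
    // the type is a TensorType) and record the identifier -> stmt binding in
    // the enclosing block.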
67 | auto block = stmt->parent; |
68 | auto ident = stmt->ident; |
69 | TI_ASSERT(block->local_var_to_stmt.find(ident) == |
70 | block->local_var_to_stmt.end()); |
71 | if (stmt->ret_type->is<TensorType>()) { |
72 | auto tensor_type = stmt->ret_type->cast<TensorType>(); |
73 | auto lowered = std::make_unique<AllocaStmt>( |
74 | tensor_type->get_shape(), tensor_type->get_element_type(), |
75 | stmt->is_shared); |
76 | block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); |
77 | stmt->parent->replace_with(stmt, std::move(lowered)); |
78 | } else { |
79 | auto lowered = std::make_unique<AllocaStmt>(stmt->ret_type); |
80 | block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get())); |
81 | stmt->parent->replace_with(stmt, std::move(lowered)); |
82 | } |
83 | } |
84 | |
85 | void visit(FrontendFuncCallStmt *stmt) override { |
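    // Flatten every call argument, emit the lowered FuncCallStmt, and, if the
    // call result is bound to an identifier, register that binding.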
86 | Block *block = stmt->parent; |
87 | std::vector<Stmt *> args; |
88 | args.reserve(stmt->args.exprs.size()); |
89 | auto fctx = make_flatten_ctx(); |
90 | for (const auto &arg : stmt->args.exprs) { |
91 | args.push_back(flatten_rvalue(arg, &fctx)); |
92 | } |
93 | auto lowered = fctx.push_back<FuncCallStmt>(stmt->func, args); |
94 | stmt->parent->replace_with(stmt, std::move(fctx.stmts)); |
95 | if (const auto &ident = stmt->ident) { |
96 | TI_ASSERT(block->local_var_to_stmt.find(ident.value()) == |
97 | block->local_var_to_stmt.end()); |
98 | block->local_var_to_stmt.insert(std::make_pair(ident.value(), lowered)); |
99 | } |
100 | } |
101 | |
102 | void visit(FrontendIfStmt *stmt) override { |
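    // Flatten the condition, move both frontend branches into a backend
    // IfStmt, then recursively lower the branch bodies.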
103 | auto fctx = make_flatten_ctx(); |
104 | auto condition_stmt = flatten_rvalue(stmt->condition, &fctx); |
105 | |
106 | auto new_if = std::make_unique<IfStmt>(condition_stmt); |
107 | |
108 | if (stmt->true_statements) { |
109 | new_if->set_true_statements(std::move(stmt->true_statements)); |
110 | } |
111 | if (stmt->false_statements) { |
112 | new_if->set_false_statements(std::move(stmt->false_statements)); |
113 | } |
114 | auto pif = new_if.get(); |
115 | fctx.push_back(std::move(new_if)); |
116 | stmt->parent->replace_with(stmt, std::move(fctx.stmts)); |
117 | pif->accept(this); |
118 | } |
119 | |
120 | void visit(IfStmt *if_stmt) override { |
121 | if (if_stmt->true_statements) |
122 | if_stmt->true_statements->accept(this); |
123 | if (if_stmt->false_statements) { |
124 | if_stmt->false_statements->accept(this); |
125 | } |
126 | } |
127 | |
128 | void visit(FrontendPrintStmt *stmt) override { |
    // Flatten each Expr operand of the print; string literals pass through
    // unchanged.
130 | std::vector<Stmt *> stmts; |
131 | std::vector<std::variant<Stmt *, std::string>> new_contents; |
132 | auto fctx = make_flatten_ctx(); |
133 | for (auto c : stmt->contents) { |
134 | if (std::holds_alternative<Expr>(c)) { |
135 | auto x = std::get<Expr>(c); |
136 | auto x_stmt = flatten_rvalue(x, &fctx); |
137 | stmts.push_back(x_stmt); |
138 | new_contents.push_back(x_stmt); |
139 | } else { |
140 | auto x = std::get<std::string>(c); |
141 | new_contents.push_back(x); |
142 | } |
143 | } |
144 | fctx.push_back<PrintStmt>(new_contents); |
145 | stmt->parent->replace_with(stmt, std::move(fctx.stmts)); |
146 | } |
147 | |
148 | void visit(FrontendBreakStmt *stmt) override { |
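    // `break` is lowered to a WhileControlStmt on the enclosing while loop's
    // mask with a constant-false condition, i.e. an unconditional break.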
149 | auto while_stmt = capturing_loop_->as<WhileStmt>(); |
150 | VecStatement stmts; |
    auto const_false = stmts.push_back<ConstStmt>(TypedConstant((int32)0));
    stmts.push_back<WhileControlStmt>(while_stmt->mask, const_false);
153 | stmt->parent->replace_with(stmt, std::move(stmts)); |
154 | } |
155 | |
156 | void visit(FrontendContinueStmt *stmt) override { |
157 | stmt->parent->replace_with(stmt, Stmt::make<ContinueStmt>()); |
158 | } |
159 | |
160 | void visit(FrontendWhileStmt *stmt) override { |
    // Transform into the following structure:
    //   while (1) { cond; if (none active) break; original body... }
163 | auto cond = stmt->cond; |
164 | auto fctx = make_flatten_ctx(); |
165 | auto cond_stmt = flatten_rvalue(cond, &fctx); |
166 | |
167 | auto &&new_while = std::make_unique<WhileStmt>(std::move(stmt->body)); |
168 | auto mask = std::make_unique<AllocaStmt>(PrimitiveType::i32); |
169 | new_while->mask = mask.get(); |
170 | auto &stmts = new_while->body; |
171 | stmts->insert(std::move(fctx.stmts), /*location=*/0); |
172 | // insert break |
173 | stmts->insert( |
174 | std::make_unique<WhileControlStmt>(new_while->mask, cond_stmt), |
175 | fctx.stmts.size()); |
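    // Allocate the loop mask and initialize it to 0xFFFFFFFF (all active)
    // right before the lowered while loop.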
176 | auto &&const_stmt = |
177 | std::make_unique<ConstStmt>(TypedConstant((int32)0xFFFFFFFF)); |
178 | auto const_stmt_ptr = const_stmt.get(); |
179 | stmt->insert_before_me(std::move(mask)); |
180 | stmt->insert_before_me(std::move(const_stmt)); |
181 | stmt->insert_before_me( |
182 | std::make_unique<LocalStoreStmt>(new_while->mask, const_stmt_ptr)); |
183 | auto pwhile = new_while.get(); |
184 | stmt->parent->replace_with(stmt, std::move(new_while)); |
185 | pwhile->accept(this); |
187 | } |
188 | |
189 | void visit(WhileStmt *stmt) override { |
190 | auto old_capturing_loop = capturing_loop_; |
191 | capturing_loop_ = stmt; |
192 | stmt->body->accept(this); |
193 | capturing_loop_ = old_capturing_loop; |
194 | } |
195 | |
196 | void visit(LoopIndexStmt *stmt) override { |
197 | // do nothing |
198 | } |
199 | |
200 | void visit(BinaryOpStmt *stmt) override { |
201 | // do nothing |
202 | } |
203 | |
204 | void visit(FrontendForStmt *stmt) override { |
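    // A frontend for loop lowers to one of: a StructForStmt (looping over an
    // SNode), a RangeForStmt over a linearized external tensor / texture, a
    // MeshForStmt, or a plain RangeForStmt; range-fors containing a break are
    // lowered to a while loop instead.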
205 | auto fctx = make_flatten_ctx(); |
206 | if (stmt->snode) { |
207 | auto snode = stmt->snode; |
208 | std::vector<int> offsets; |
209 | if (snode->type == SNodeType::place) { |
210 | /* Note: |
211 | * for i in x: |
212 | * x[i] = 0 |
213 | * |
214 | * has the same effect as |
215 | * |
216 | * for i in x.parent(): |
217 | * x[i] = 0 |
218 | * |
219 | * (unless x has index offsets)*/ |
220 | offsets = snode->index_offsets; |
221 | snode = snode->parent; |
222 | } |
223 | |
      // Climb up one more level if inside a bit_struct.
      // Note that when looping over bit_structs, we generate the struct-for
      // on their parent node rather than on the bit_struct itself, for higher
      // performance.
228 | if (snode->type == SNodeType::bit_struct) |
229 | snode = snode->parent; |
230 | |
231 | auto &&new_for = std::make_unique<StructForStmt>( |
232 | snode, std::move(stmt->body), stmt->is_bit_vectorized, |
233 | stmt->num_cpu_threads, stmt->block_dim); |
234 | new_for->index_offsets = offsets; |
235 | VecStatement new_statements; |
236 | for (int i = 0; i < (int)stmt->loop_var_ids.size(); i++) { |
237 | Stmt *loop_index = new_statements.push_back<LoopIndexStmt>( |
238 | new_for.get(), snode->physical_index_position[i]); |
239 | if ((int)offsets.size() > i && offsets[i] != 0) { |
240 | auto offset_const = |
241 | new_statements.push_back<ConstStmt>(TypedConstant(offsets[i])); |
242 | auto result = new_statements.push_back<BinaryOpStmt>( |
243 | BinaryOpType::add, loop_index, offset_const); |
244 | loop_index = result; |
245 | } |
246 | new_for->body->local_var_to_stmt[stmt->loop_var_ids[i]] = loop_index; |
247 | } |
248 | new_for->body->insert(std::move(new_statements), 0); |
249 | new_for->mem_access_opt = stmt->mem_access_opt; |
250 | fctx.push_back(std::move(new_for)); |
251 | } else if (stmt->external_tensor) { |
252 | int arg_id = -1; |
253 | std::vector<Stmt *> shape; |
254 | if (stmt->external_tensor.is<ExternalTensorExpression>()) { |
255 | auto tensor = stmt->external_tensor.cast<ExternalTensorExpression>(); |
256 | arg_id = tensor->arg_id; |
257 | for (int i = 0; i < tensor->dim - abs(tensor->element_dim); i++) { |
258 | shape.push_back( |
259 | fctx.push_back<ExternalTensorShapeAlongAxisStmt>(i, arg_id)); |
260 | } |
261 | } else if (stmt->external_tensor.is<TexturePtrExpression>()) { |
262 | auto rw_texture = stmt->external_tensor.cast<TexturePtrExpression>(); |
263 | arg_id = rw_texture->arg_id; |
264 | for (size_t i = 0; i < rw_texture->num_dims; ++i) { |
265 | shape.emplace_back( |
266 | fctx.push_back<ExternalTensorShapeAlongAxisStmt>(i, arg_id)); |
267 | } |
268 | } |
269 | |
270 | Stmt *begin = fctx.push_back<ConstStmt>(TypedConstant(0)); |
271 | Stmt *end = fctx.push_back<ConstStmt>(TypedConstant(1)); |
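      // end = prod(shape): the possibly multi-dimensional loop is linearized
      // into the 1-D range [0, prod(shape)); the per-axis loop variables are
      // reconstructed inside the body below.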
272 | for (int i = 0; i < (int)shape.size(); i++) { |
273 | end = fctx.push_back<BinaryOpStmt>(BinaryOpType::mul, end, shape[i]); |
274 | } |
275 | // TODO: add a note explaining why shape might be empty. |
276 | auto &&new_for = std::make_unique<RangeForStmt>( |
277 | begin, end, std::move(stmt->body), stmt->is_bit_vectorized, |
278 | stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized, |
          /*range_hint=*/fmt::format("arg {}", arg_id));
280 | VecStatement new_statements; |
281 | Stmt *loop_index = |
282 | new_statements.push_back<LoopIndexStmt>(new_for.get(), 0); |
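      // Decode the per-axis loop variables from the linear index: the last
      // axis varies fastest, so peel each axis off with mod and advance with
      // div, marking each decoded index as loop-unique.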
283 | for (int i = (int)shape.size() - 1; i >= 0; i--) { |
284 | Stmt *loop_var = new_statements.push_back<BinaryOpStmt>( |
285 | BinaryOpType::mod, loop_index, shape[i]); |
286 | new_for->body->local_var_to_stmt[stmt->loop_var_ids[i]] = loop_var; |
287 | std::vector<uint32_t> decoration = { |
288 | uint32_t(DecorationStmt::Decoration::kLoopUnique), uint32_t(i)}; |
289 | new_statements.push_back<DecorationStmt>(loop_var, decoration); |
290 | loop_index = new_statements.push_back<BinaryOpStmt>( |
291 | BinaryOpType::div, loop_index, shape[i]); |
292 | } |
293 | new_for->body->insert(std::move(new_statements), 0); |
294 | fctx.push_back(std::move(new_for)); |
295 | } else if (stmt->mesh) { |
296 | auto &&new_for = std::make_unique<MeshForStmt>( |
297 | stmt->mesh, stmt->element_type, std::move(stmt->body), |
298 | stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim); |
299 | new_for->body->insert(std::make_unique<LoopIndexStmt>(new_for.get(), 0), |
300 | 0); |
301 | new_for->body->local_var_to_stmt[stmt->loop_var_ids[0]] = |
302 | new_for->body->statements[0].get(); |
303 | new_for->mem_access_opt = stmt->mem_access_opt; |
304 | new_for->fields_registered = true; |
305 | fctx.push_back(std::move(new_for)); |
306 | } else { |
307 | TI_ASSERT(stmt->loop_var_ids.size() == 1); |
308 | auto begin = stmt->begin; |
309 | auto end = stmt->end; |
310 | auto begin_stmt = flatten_rvalue(begin, &fctx); |
311 | auto end_stmt = flatten_rvalue(end, &fctx); |
312 | bool is_good_range_for = detected_fors_with_break_.find(stmt) == |
313 | detected_fors_with_break_.end(); |
314 | // #578: a good range for is a range for that doesn't contain a break |
315 | // statement |
316 | if (is_good_range_for) { |
317 | auto &&new_for = std::make_unique<RangeForStmt>( |
318 | begin_stmt, end_stmt, std::move(stmt->body), |
319 | stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim, |
320 | stmt->strictly_serialized); |
321 | new_for->body->insert(std::make_unique<LoopIndexStmt>(new_for.get(), 0), |
322 | 0); |
323 | new_for->body->local_var_to_stmt[stmt->loop_var_ids[0]] = |
324 | new_for->body->statements[0].get(); |
325 | fctx.push_back(std::move(new_for)); |
326 | } else { |
327 | // transform into a structure as |
328 | // i = begin - 1; while (1) { i += 1; if (i >= end) break; original |
329 | // body; } |
330 | fctx.push_back<AllocaStmt>(PrimitiveType::i32); |
331 | auto loop_var = fctx.back_stmt(); |
332 | stmt->parent->local_var_to_stmt[stmt->loop_var_ids[0]] = loop_var; |
333 | auto const_one = fctx.push_back<ConstStmt>(TypedConstant((int32)1)); |
334 | auto begin_minus_one = fctx.push_back<BinaryOpStmt>( |
335 | BinaryOpType::sub, begin_stmt, const_one); |
336 | fctx.push_back<LocalStoreStmt>(loop_var, begin_minus_one); |
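        // Loop body prologue (prepended below): load i, compute i + 1, break
        // if i + 1 >= end, then store i + 1 back into the loop variable.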
337 | auto loop_var_addr = loop_var->as<AllocaStmt>(); |
338 | VecStatement load_and_compare; |
339 | auto loop_var_load_stmt = |
340 | load_and_compare.push_back<LocalLoadStmt>(loop_var_addr); |
341 | auto loop_var_add_one = load_and_compare.push_back<BinaryOpStmt>( |
342 | BinaryOpType::add, loop_var_load_stmt, const_one); |
343 | |
344 | auto cond_stmt = load_and_compare.push_back<BinaryOpStmt>( |
345 | BinaryOpType::cmp_lt, loop_var_add_one, end_stmt); |
346 | |
347 | auto &&new_while = std::make_unique<WhileStmt>(std::move(stmt->body)); |
348 | auto mask = std::make_unique<AllocaStmt>(PrimitiveType::i32); |
349 | new_while->mask = mask.get(); |
350 | |
351 | // insert break |
352 | load_and_compare.push_back<WhileControlStmt>(new_while->mask, |
353 | cond_stmt); |
354 | load_and_compare.push_back<LocalStoreStmt>(loop_var, loop_var_add_one); |
355 | auto &stmts = new_while->body; |
356 | for (int i = 0; i < (int)load_and_compare.size(); i++) { |
357 | stmts->insert(std::move(load_and_compare[i]), i); |
358 | } |
359 | |
360 | stmt->insert_before_me( |
361 | std::make_unique<AllocaStmt>(PrimitiveType::i32)); |
362 | auto &&const_stmt = |
363 | std::make_unique<ConstStmt>(TypedConstant((int32)0xFFFFFFFF)); |
364 | auto const_stmt_ptr = const_stmt.get(); |
365 | stmt->insert_before_me(std::move(mask)); |
366 | stmt->insert_before_me(std::move(const_stmt)); |
367 | stmt->insert_before_me( |
368 | std::make_unique<LocalStoreStmt>(new_while->mask, const_stmt_ptr)); |
369 | fctx.push_back(std::move(new_while)); |
370 | } |
371 | } |
372 | auto pfor = fctx.stmts.back().get(); |
373 | stmt->parent->replace_with(stmt, std::move(fctx.stmts)); |
374 | pfor->accept(this); |
375 | } |
376 | |
377 | void visit(RangeForStmt *for_stmt) override { |
378 | auto old_capturing_loop = capturing_loop_; |
379 | capturing_loop_ = for_stmt; |
380 | for_stmt->body->accept(this); |
381 | capturing_loop_ = old_capturing_loop; |
382 | } |
383 | |
384 | void visit(StructForStmt *for_stmt) override { |
385 | auto old_capturing_loop = capturing_loop_; |
386 | capturing_loop_ = for_stmt; |
387 | for_stmt->body->accept(this); |
388 | capturing_loop_ = old_capturing_loop; |
389 | } |
390 | |
391 | void visit(MeshForStmt *for_stmt) override { |
392 | auto old_capturing_loop = capturing_loop_; |
393 | capturing_loop_ = for_stmt; |
394 | for_stmt->body->accept(this); |
395 | capturing_loop_ = old_capturing_loop; |
396 | } |
397 | |
398 | void visit(FrontendReturnStmt *stmt) override { |
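    // Flatten each returned expression and emit a single ReturnStmt carrying
    // all of them.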
399 | auto expr_group = stmt->values; |
400 | auto fctx = make_flatten_ctx(); |
401 | std::vector<Stmt *> return_ele; |
402 | for (auto &x : expr_group.exprs) { |
403 | return_ele.push_back(flatten_rvalue(x, &fctx)); |
404 | } |
405 | fctx.push_back<ReturnStmt>(return_ele); |
406 | stmt->parent->replace_with(stmt, std::move(fctx.stmts)); |
407 | } |
408 | |
409 | void visit(FrontendAssignStmt *assign) override { |
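    // Flatten both sides, then pick the store kind: LocalStoreStmt for local
    // destinations (identifiers and local tensor elements), GlobalStoreStmt
    // otherwise (global fields and pointer arguments).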
410 | auto dest = assign->lhs; |
411 | auto expr = assign->rhs; |
412 | auto fctx = make_flatten_ctx(); |
413 | auto expr_stmt = flatten_rvalue(expr, &fctx); |
414 | auto dest_stmt = flatten_lvalue(dest, &fctx); |
415 | if (dest.is<IdExpression>()) { |
416 | fctx.push_back<LocalStoreStmt>(dest_stmt, expr_stmt); |
417 | } else if (dest.is<IndexExpression>()) { |
418 | auto ix = dest.cast<IndexExpression>(); |
419 | if (ix->is_local()) { |
420 | fctx.push_back<LocalStoreStmt>(dest_stmt, expr_stmt); |
421 | } else { |
422 | fctx.push_back<GlobalStoreStmt>(dest_stmt, expr_stmt); |
423 | } |
424 | } else { |
425 | TI_ASSERT(dest.is<ArgLoadExpression>() && |
426 | dest.cast<ArgLoadExpression>()->is_ptr); |
427 | fctx.push_back<GlobalStoreStmt>(dest_stmt, expr_stmt); |
428 | } |
429 | fctx.stmts.back()->set_tb(assign->tb); |
430 | assign->parent->replace_with(assign, std::move(fctx.stmts)); |
431 | } |
432 | |
433 | void visit(FrontendSNodeOpStmt *stmt) override { |
    // Flatten the optional value operand and the indices.
435 | Stmt *val_stmt = nullptr; |
436 | auto fctx = make_flatten_ctx(); |
437 | if (stmt->val.expr) { |
438 | val_stmt = flatten_rvalue(stmt->val, &fctx); |
439 | } |
440 | std::vector<Stmt *> indices_stmt(stmt->indices.size(), nullptr); |
441 | |
442 | for (int i = 0; i < (int)stmt->indices.size(); i++) { |
443 | indices_stmt[i] = flatten_rvalue(stmt->indices[i], &fctx); |
444 | } |
445 | |
446 | if (stmt->snode->type == SNodeType::dynamic) { |
447 | auto ptr = fctx.push_back<GlobalPtrStmt>(stmt->snode, indices_stmt); |
448 | fctx.push_back<SNodeOpStmt>(stmt->op_type, stmt->snode, ptr, val_stmt); |
449 | } else if (stmt->snode->type == SNodeType::pointer || |
450 | stmt->snode->type == SNodeType::hash || |
451 | stmt->snode->type == SNodeType::dense || |
452 | stmt->snode->type == SNodeType::bitmasked) { |
453 | TI_ASSERT(SNodeOpStmt::activation_related(stmt->op_type)); |
454 | auto ptr = |
455 | fctx.push_back<GlobalPtrStmt>(stmt->snode, indices_stmt, true, true); |
456 | fctx.push_back<SNodeOpStmt>(stmt->op_type, stmt->snode, ptr, val_stmt); |
457 | } else { |
      TI_ERROR("The {} operation is not supported on {} SNode",
459 | snode_op_type_name(stmt->op_type), |
460 | snode_type_name(stmt->snode->type)); |
461 | TI_NOT_IMPLEMENTED |
462 | } |
463 | |
464 | stmt->parent->replace_with(stmt, std::move(fctx.stmts)); |
465 | } |
466 | |
467 | void visit(FrontendAssertStmt *stmt) override { |
    // Flatten the assert condition and its format arguments.
469 | Stmt *val_stmt = nullptr; |
470 | auto fctx = make_flatten_ctx(); |
471 | if (stmt->cond.expr) { |
472 | val_stmt = flatten_rvalue(stmt->cond, &fctx); |
473 | } |
474 | |
475 | auto &fargs = stmt->args; // frontend stmt args |
476 | std::vector<Stmt *> args_stmts(fargs.size()); |
477 | for (int i = 0; i < (int)fargs.size(); ++i) { |
478 | args_stmts[i] = flatten_rvalue(fargs[i], &fctx); |
479 | } |
480 | fctx.push_back<AssertStmt>(val_stmt, stmt->text, args_stmts); |
481 | stmt->parent->replace_with(stmt, std::move(fctx.stmts)); |
482 | } |
483 | |
484 | void visit(FrontendExprStmt *stmt) override { |
485 | auto fctx = make_flatten_ctx(); |
486 | flatten_rvalue(stmt->val, &fctx); |
487 | stmt->parent->replace_with(stmt, std::move(fctx.stmts)); |
488 | } |
489 | |
490 | void visit(FrontendExternalFuncStmt *stmt) override { |
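    // Exactly one call source must be provided: a shared-object function
    // pointer, inline assembly, or a bitcode file. The bitcode path requires
    // every argument to be a local variable, flattened as an lvalue.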
491 | auto ctx = make_flatten_ctx(); |
492 | TI_ASSERT((int)(stmt->so_func != nullptr) + |
493 | (int)(!stmt->asm_source.empty()) + |
494 | (int)(!stmt->bc_filename.empty()) == |
495 | 1); |
496 | std::vector<Stmt *> arg_statements, output_statements; |
497 | if (stmt->so_func != nullptr || !stmt->asm_source.empty()) { |
498 | for (auto &s : stmt->args) { |
499 | arg_statements.push_back(flatten_rvalue(s, &ctx)); |
500 | } |
501 | for (auto &s : stmt->outputs) { |
502 | output_statements.push_back(flatten_lvalue(s, &ctx)); |
503 | } |
504 | ctx.push_back(std::make_unique<ExternalFuncCallStmt>( |
505 | (stmt->so_func != nullptr) ? ExternalFuncCallStmt::SHARED_OBJECT |
506 | : ExternalFuncCallStmt::ASSEMBLY, |
          stmt->so_func, stmt->asm_source, "", "", arg_statements,
508 | output_statements)); |
509 | } else { |
510 | for (auto &s : stmt->args) { |
        TI_ASSERT_INFO(
            s.is<IdExpression>(),
            "external func call via bitcode must pass in local variables.")
514 | arg_statements.push_back(flatten_lvalue(s, &ctx)); |
515 | } |
516 | ctx.push_back(std::make_unique<ExternalFuncCallStmt>( |
          ExternalFuncCallStmt::BITCODE, nullptr, "", stmt->bc_filename,
518 | stmt->bc_funcname, arg_statements, output_statements)); |
519 | } |
520 | stmt->parent->replace_with(stmt, std::move(ctx.stmts)); |
521 | } |
522 | |
523 | static void run(IRNode *node) { |
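    // Detect range-fors that contain a break first; they must be lowered to
    // while loops rather than RangeForStmts.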
524 | LowerAST inst(irpass::analysis::detect_fors_with_break(node)); |
525 | node->accept(&inst); |
526 | } |
527 | }; |
528 | |
529 | namespace irpass { |
530 | |
531 | void lower_ast(IRNode *root) { |
532 | TI_AUTO_PROF; |
533 | LowerAST::run(root); |
534 | } |
535 | |
536 | } // namespace irpass |
537 | |
538 | } // namespace taichi::lang |
539 | |