#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/frontend_ir.h"
#include "taichi/system/profiler.h"

#include <unordered_set>

namespace taichi::lang {

namespace {

using FlattenContext = Expression::FlattenContext;

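// Snapshot raw pointers out of a container of unique_ptrs. visit(Block *)
// iterates over such a snapshot so that visiting a statement can replace it
// inside the original statement list without invalidating the iteration.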
template <typename Vec>
std::vector<typename Vec::value_type::pointer> make_raw_pointer_list(
    const Vec &unique_pointers) {
  std::vector<typename Vec::value_type::pointer> raw_pointers;
  for (auto &ptr : unique_pointers)
    raw_pointers.push_back(ptr.get());
  return raw_pointers;
}

}  // namespace

// Lower the frontend Expr trees into sequences of binary/unary statements.
// Goal: eliminate Expressions, Identifiers, and mutable local variables,
// turning the AST into SSA form.
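//
// Illustrative sketch (the statement names below are a rough sketch, not
// verbatim compiler output): a frontend assignment
//
//   a = b * 2 + 1
//
// lowers to roughly
//
//   %b  <- LocalLoadStmt(alloca_b)
//   %c2 <- ConstStmt(2)
//   %t0 <- BinaryOpStmt(mul, %b, %c2)
//   %c1 <- ConstStmt(1)
//   %t1 <- BinaryOpStmt(add, %t0, %c1)
//          LocalStoreStmt(alloca_a, %t1)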
class LowerAST : public IRVisitor {
 private:
  Stmt *capturing_loop_;
  std::unordered_set<Stmt *> detected_fors_with_break_;
  Block *current_block_;
  int current_block_depth_;

  FlattenContext make_flatten_ctx() {
    FlattenContext fctx;
    fctx.current_block = this->current_block_;
    return fctx;
  }

 public:
  explicit LowerAST(const std::unordered_set<Stmt *> &_detected_fors_with_break)
      : detected_fors_with_break_(_detected_fors_with_break),
        current_block_(nullptr),
        current_block_depth_(0) {
    // TODO: change this to false
    allow_undefined_visitor = true;
    capturing_loop_ = nullptr;
  }

  void visit(Block *stmt_list) override {
    auto backup_block = this->current_block_;
    this->current_block_ = stmt_list;
    auto stmts = make_raw_pointer_list(stmt_list->statements);
    current_block_depth_++;
    for (auto &stmt : stmts) {
      stmt->accept(this);
    }
    current_block_depth_--;
    this->current_block_ = backup_block;
  }

  void visit(FrontendAllocaStmt *stmt) override {
    auto block = stmt->parent;
    auto ident = stmt->ident;
    TI_ASSERT(block->local_var_to_stmt.find(ident) ==
              block->local_var_to_stmt.end());
    if (stmt->ret_type->is<TensorType>()) {
      auto tensor_type = stmt->ret_type->cast<TensorType>();
      auto lowered = std::make_unique<AllocaStmt>(
          tensor_type->get_shape(), tensor_type->get_element_type(),
          stmt->is_shared);
      block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get()));
      stmt->parent->replace_with(stmt, std::move(lowered));
    } else {
      auto lowered = std::make_unique<AllocaStmt>(stmt->ret_type);
      block->local_var_to_stmt.insert(std::make_pair(ident, lowered.get()));
      stmt->parent->replace_with(stmt, std::move(lowered));
    }
  }

  void visit(FrontendFuncCallStmt *stmt) override {
    Block *block = stmt->parent;
    std::vector<Stmt *> args;
    args.reserve(stmt->args.exprs.size());
    auto fctx = make_flatten_ctx();
    for (const auto &arg : stmt->args.exprs) {
      args.push_back(flatten_rvalue(arg, &fctx));
    }
    auto lowered = fctx.push_back<FuncCallStmt>(stmt->func, args);
    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
    if (const auto &ident = stmt->ident) {
      TI_ASSERT(block->local_var_to_stmt.find(ident.value()) ==
                block->local_var_to_stmt.end());
      block->local_var_to_stmt.insert(std::make_pair(ident.value(), lowered));
    }
  }

  void visit(FrontendIfStmt *stmt) override {
    auto fctx = make_flatten_ctx();
    auto condition_stmt = flatten_rvalue(stmt->condition, &fctx);

    auto new_if = std::make_unique<IfStmt>(condition_stmt);

    if (stmt->true_statements) {
      new_if->set_true_statements(std::move(stmt->true_statements));
    }
    if (stmt->false_statements) {
      new_if->set_false_statements(std::move(stmt->false_statements));
    }
    auto pif = new_if.get();
    fctx.push_back(std::move(new_if));
    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
    pif->accept(this);
  }

  void visit(IfStmt *if_stmt) override {
    if (if_stmt->true_statements)
      if_stmt->true_statements->accept(this);
    if (if_stmt->false_statements) {
      if_stmt->false_statements->accept(this);
    }
  }

  void visit(FrontendPrintStmt *stmt) override {
    // expand rhs
    std::vector<Stmt *> stmts;
    std::vector<std::variant<Stmt *, std::string>> new_contents;
    auto fctx = make_flatten_ctx();
    for (auto c : stmt->contents) {
      if (std::holds_alternative<Expr>(c)) {
        auto x = std::get<Expr>(c);
        auto x_stmt = flatten_rvalue(x, &fctx);
        stmts.push_back(x_stmt);
        new_contents.push_back(x_stmt);
      } else {
        auto x = std::get<std::string>(c);
        new_contents.push_back(x);
      }
    }
    fctx.push_back<PrintStmt>(new_contents);
    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
  }

  void visit(FrontendBreakStmt *stmt) override {
    auto while_stmt = capturing_loop_->as<WhileStmt>();
    VecStatement stmts;
    // Store 0 ("no lanes active") into the enclosing loop's mask; the
    // WhileControlStmt then terminates the loop.
    auto const_false = stmts.push_back<ConstStmt>(TypedConstant((int32)0));
    stmts.push_back<WhileControlStmt>(while_stmt->mask, const_false);
    stmt->parent->replace_with(stmt, std::move(stmts));
  }

  void visit(FrontendContinueStmt *stmt) override {
    stmt->parent->replace_with(stmt, Stmt::make<ContinueStmt>());
  }

  void visit(FrontendWhileStmt *stmt) override {
    // Transform into:
    //   while (1) { cond; if (!cond) break; <original body> }
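    //
    // A rough sketch of the emitted IR (names are illustrative):
    //
    //   mask <- AllocaStmt(i32)
    //   LocalStoreStmt(mask, 0xFFFFFFFF)   // all lanes active
    //   WhileStmt {
    //     <statements computing cond>
    //     WhileControlStmt(mask, cond)     // clears the mask when !cond
    //     <original body>
    //   }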
    auto cond = stmt->cond;
    auto fctx = make_flatten_ctx();
    auto cond_stmt = flatten_rvalue(cond, &fctx);

    auto &&new_while = std::make_unique<WhileStmt>(std::move(stmt->body));
    auto mask = std::make_unique<AllocaStmt>(PrimitiveType::i32);
    new_while->mask = mask.get();
    auto &stmts = new_while->body;
    // Capture the number of condition statements before fctx.stmts is moved
    // from.
    const int num_cond_stmts = (int)fctx.stmts.size();
    stmts->insert(std::move(fctx.stmts), /*location=*/0);
    // Insert the break check right after the condition is evaluated.
    stmts->insert(
        std::make_unique<WhileControlStmt>(new_while->mask, cond_stmt),
        num_cond_stmts);
    // Allocate the mask outside the loop and initialize it to all ones
    // ("all lanes active") before entering the loop.
    auto &&const_stmt =
        std::make_unique<ConstStmt>(TypedConstant((int32)0xFFFFFFFF));
    auto const_stmt_ptr = const_stmt.get();
    stmt->insert_before_me(std::move(mask));
    stmt->insert_before_me(std::move(const_stmt));
    stmt->insert_before_me(
        std::make_unique<LocalStoreStmt>(new_while->mask, const_stmt_ptr));
    auto pwhile = new_while.get();
    stmt->parent->replace_with(stmt, std::move(new_while));
    pwhile->accept(this);
  }

  void visit(WhileStmt *stmt) override {
    auto old_capturing_loop = capturing_loop_;
    capturing_loop_ = stmt;
    stmt->body->accept(this);
    capturing_loop_ = old_capturing_loop;
  }

  void visit(LoopIndexStmt *stmt) override {
    // do nothing
  }

  void visit(BinaryOpStmt *stmt) override {
    // do nothing
  }

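  // Lowers a frontend for-loop into one of four forms:
  //   - a StructForStmt when looping over an SNode,
  //   - a RangeForStmt over the flattened element count when looping over an
  //     external tensor or texture,
  //   - a MeshForStmt when looping over a mesh,
  //   - a RangeForStmt (or, if the body contains a break, a WhileStmt) for a
  //     plain range-for.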
  void visit(FrontendForStmt *stmt) override {
    auto fctx = make_flatten_ctx();
    if (stmt->snode) {
      auto snode = stmt->snode;
      std::vector<int> offsets;
      if (snode->type == SNodeType::place) {
        /* Note:
         * for i in x:
         *   x[i] = 0
         *
         * has the same effect as
         *
         * for i in x.parent():
         *   x[i] = 0
         *
         * (unless x has index offsets) */
        offsets = snode->index_offsets;
        snode = snode->parent;
      }

      // Climb up one more level if inside a bit_struct. Note that when
      // looping over a bit_struct, we generate the struct-for on its parent
      // node instead, for higher performance.
      if (snode->type == SNodeType::bit_struct)
        snode = snode->parent;

      auto &&new_for = std::make_unique<StructForStmt>(
          snode, std::move(stmt->body), stmt->is_bit_vectorized,
          stmt->num_cpu_threads, stmt->block_dim);
      new_for->index_offsets = offsets;
      VecStatement new_statements;
      for (int i = 0; i < (int)stmt->loop_var_ids.size(); i++) {
        Stmt *loop_index = new_statements.push_back<LoopIndexStmt>(
            new_for.get(), snode->physical_index_position[i]);
        if ((int)offsets.size() > i && offsets[i] != 0) {
          auto offset_const =
              new_statements.push_back<ConstStmt>(TypedConstant(offsets[i]));
          auto result = new_statements.push_back<BinaryOpStmt>(
              BinaryOpType::add, loop_index, offset_const);
          loop_index = result;
        }
        new_for->body->local_var_to_stmt[stmt->loop_var_ids[i]] = loop_index;
      }
      new_for->body->insert(std::move(new_statements), 0);
      new_for->mem_access_opt = stmt->mem_access_opt;
      fctx.push_back(std::move(new_for));
    } else if (stmt->external_tensor) {
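      // Lower `for i, j, ... in ext_arr` into a 1-D range-for over the
      // flattened element count; the per-axis indices are recovered below
      // with div/mod.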
      int arg_id = -1;
      std::vector<Stmt *> shape;
      if (stmt->external_tensor.is<ExternalTensorExpression>()) {
        auto tensor = stmt->external_tensor.cast<ExternalTensorExpression>();
        arg_id = tensor->arg_id;
        for (int i = 0; i < tensor->dim - abs(tensor->element_dim); i++) {
          shape.push_back(
              fctx.push_back<ExternalTensorShapeAlongAxisStmt>(i, arg_id));
        }
      } else if (stmt->external_tensor.is<TexturePtrExpression>()) {
        auto rw_texture = stmt->external_tensor.cast<TexturePtrExpression>();
        arg_id = rw_texture->arg_id;
        for (size_t i = 0; i < rw_texture->num_dims; ++i) {
          shape.emplace_back(
              fctx.push_back<ExternalTensorShapeAlongAxisStmt>(i, arg_id));
        }
      }

      Stmt *begin = fctx.push_back<ConstStmt>(TypedConstant(0));
      Stmt *end = fctx.push_back<ConstStmt>(TypedConstant(1));
      for (int i = 0; i < (int)shape.size(); i++) {
        end = fctx.push_back<BinaryOpStmt>(BinaryOpType::mul, end, shape[i]);
      }
      // Note: shape is empty when the tensor has no spatial axes (e.g. a
      // 0-dimensional ndarray, where dim == abs(element_dim)); the range-for
      // below then covers [0, 1) and runs exactly once.
      auto &&new_for = std::make_unique<RangeForStmt>(
          begin, end, std::move(stmt->body), stmt->is_bit_vectorized,
          stmt->num_cpu_threads, stmt->block_dim, stmt->strictly_serialized,
          /*range_hint=*/fmt::format("arg {}", arg_id));
      VecStatement new_statements;
      Stmt *loop_index =
          new_statements.push_back<LoopIndexStmt>(new_for.get(), 0);
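      // Recover the per-axis indices from the flattened loop index, last
      // axis fastest; illustratively, for shape (s0, s1):
      //   i1 = linear % s1;  linear /= s1;  i0 = linear % s0;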
      for (int i = (int)shape.size() - 1; i >= 0; i--) {
        Stmt *loop_var = new_statements.push_back<BinaryOpStmt>(
            BinaryOpType::mod, loop_index, shape[i]);
        new_for->body->local_var_to_stmt[stmt->loop_var_ids[i]] = loop_var;
        std::vector<uint32_t> decoration = {
            uint32_t(DecorationStmt::Decoration::kLoopUnique), uint32_t(i)};
        new_statements.push_back<DecorationStmt>(loop_var, decoration);
        loop_index = new_statements.push_back<BinaryOpStmt>(
            BinaryOpType::div, loop_index, shape[i]);
      }
      new_for->body->insert(std::move(new_statements), 0);
      fctx.push_back(std::move(new_for));
    } else if (stmt->mesh) {
      auto &&new_for = std::make_unique<MeshForStmt>(
          stmt->mesh, stmt->element_type, std::move(stmt->body),
          stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim);
      new_for->body->insert(std::make_unique<LoopIndexStmt>(new_for.get(), 0),
                            0);
      new_for->body->local_var_to_stmt[stmt->loop_var_ids[0]] =
          new_for->body->statements[0].get();
      new_for->mem_access_opt = stmt->mem_access_opt;
      new_for->fields_registered = true;
      fctx.push_back(std::move(new_for));
    } else {
      TI_ASSERT(stmt->loop_var_ids.size() == 1);
      auto begin = stmt->begin;
      auto end = stmt->end;
      auto begin_stmt = flatten_rvalue(begin, &fctx);
      auto end_stmt = flatten_rvalue(end, &fctx);
      // #578: a "good" range-for is one whose body contains no break
      // statement.
      bool is_good_range_for = detected_fors_with_break_.find(stmt) ==
                               detected_fors_with_break_.end();
      if (is_good_range_for) {
        auto &&new_for = std::make_unique<RangeForStmt>(
            begin_stmt, end_stmt, std::move(stmt->body),
            stmt->is_bit_vectorized, stmt->num_cpu_threads, stmt->block_dim,
            stmt->strictly_serialized);
        new_for->body->insert(std::make_unique<LoopIndexStmt>(new_for.get(), 0),
                              0);
        new_for->body->local_var_to_stmt[stmt->loop_var_ids[0]] =
            new_for->body->statements[0].get();
        fctx.push_back(std::move(new_for));
      } else {
        // Transform into:
        //   i = begin - 1;
        //   while (1) { i += 1; if (i >= end) break; <original body> }
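        // The same mask machinery as in visit(FrontendWhileStmt) is reused
        // here, so the FrontendBreakStmt lowering inside the body can
        // terminate the loop by zeroing the mask.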
        fctx.push_back<AllocaStmt>(PrimitiveType::i32);
        auto loop_var = fctx.back_stmt();
        stmt->parent->local_var_to_stmt[stmt->loop_var_ids[0]] = loop_var;
        auto const_one = fctx.push_back<ConstStmt>(TypedConstant((int32)1));
        auto begin_minus_one = fctx.push_back<BinaryOpStmt>(
            BinaryOpType::sub, begin_stmt, const_one);
        fctx.push_back<LocalStoreStmt>(loop_var, begin_minus_one);
        auto loop_var_addr = loop_var->as<AllocaStmt>();
        VecStatement load_and_compare;
        auto loop_var_load_stmt =
            load_and_compare.push_back<LocalLoadStmt>(loop_var_addr);
        auto loop_var_add_one = load_and_compare.push_back<BinaryOpStmt>(
            BinaryOpType::add, loop_var_load_stmt, const_one);

        auto cond_stmt = load_and_compare.push_back<BinaryOpStmt>(
            BinaryOpType::cmp_lt, loop_var_add_one, end_stmt);

        auto &&new_while = std::make_unique<WhileStmt>(std::move(stmt->body));
        auto mask = std::make_unique<AllocaStmt>(PrimitiveType::i32);
        new_while->mask = mask.get();

        // insert break
        load_and_compare.push_back<WhileControlStmt>(new_while->mask,
                                                     cond_stmt);
        load_and_compare.push_back<LocalStoreStmt>(loop_var, loop_var_add_one);
        auto &stmts = new_while->body;
        for (int i = 0; i < (int)load_and_compare.size(); i++) {
          stmts->insert(std::move(load_and_compare[i]), i);
        }

        stmt->insert_before_me(
            std::make_unique<AllocaStmt>(PrimitiveType::i32));
        auto &&const_stmt =
            std::make_unique<ConstStmt>(TypedConstant((int32)0xFFFFFFFF));
        auto const_stmt_ptr = const_stmt.get();
        stmt->insert_before_me(std::move(mask));
        stmt->insert_before_me(std::move(const_stmt));
        stmt->insert_before_me(
            std::make_unique<LocalStoreStmt>(new_while->mask, const_stmt_ptr));
        fctx.push_back(std::move(new_while));
      }
    }
    auto pfor = fctx.stmts.back().get();
    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
    pfor->accept(this);
  }

  void visit(RangeForStmt *for_stmt) override {
    auto old_capturing_loop = capturing_loop_;
    capturing_loop_ = for_stmt;
    for_stmt->body->accept(this);
    capturing_loop_ = old_capturing_loop;
  }

  void visit(StructForStmt *for_stmt) override {
    auto old_capturing_loop = capturing_loop_;
    capturing_loop_ = for_stmt;
    for_stmt->body->accept(this);
    capturing_loop_ = old_capturing_loop;
  }

  void visit(MeshForStmt *for_stmt) override {
    auto old_capturing_loop = capturing_loop_;
    capturing_loop_ = for_stmt;
    for_stmt->body->accept(this);
    capturing_loop_ = old_capturing_loop;
  }

  void visit(FrontendReturnStmt *stmt) override {
    auto expr_group = stmt->values;
    auto fctx = make_flatten_ctx();
    std::vector<Stmt *> return_ele;
    for (auto &x : expr_group.exprs) {
      return_ele.push_back(flatten_rvalue(x, &fctx));
    }
    fctx.push_back<ReturnStmt>(return_ele);
    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
  }

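  // Lowers `lhs = rhs`: flatten the RHS as an rvalue and the LHS as an
  // lvalue, then store with a LocalStoreStmt (local destinations) or a
  // GlobalStoreStmt (fields, external arrays, and pointer arguments).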
  void visit(FrontendAssignStmt *assign) override {
    auto dest = assign->lhs;
    auto expr = assign->rhs;
    auto fctx = make_flatten_ctx();
    auto expr_stmt = flatten_rvalue(expr, &fctx);
    auto dest_stmt = flatten_lvalue(dest, &fctx);
    if (dest.is<IdExpression>()) {
      fctx.push_back<LocalStoreStmt>(dest_stmt, expr_stmt);
    } else if (dest.is<IndexExpression>()) {
      auto ix = dest.cast<IndexExpression>();
      if (ix->is_local()) {
        fctx.push_back<LocalStoreStmt>(dest_stmt, expr_stmt);
      } else {
        fctx.push_back<GlobalStoreStmt>(dest_stmt, expr_stmt);
      }
    } else {
      TI_ASSERT(dest.is<ArgLoadExpression>() &&
                dest.cast<ArgLoadExpression>()->is_ptr);
      fctx.push_back<GlobalStoreStmt>(dest_stmt, expr_stmt);
    }
    fctx.stmts.back()->set_tb(assign->tb);
    assign->parent->replace_with(assign, std::move(fctx.stmts));
  }

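  // Lowers SNode operations: the value operand and the indices are
  // flattened, a GlobalPtrStmt is built for the target cell, and an
  // SNodeOpStmt performs the op. Non-dynamic SNodes only support
  // activation-related ops here.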
  void visit(FrontendSNodeOpStmt *stmt) override {
    // Expand the value operand (if any) and the indices.
    Stmt *val_stmt = nullptr;
    auto fctx = make_flatten_ctx();
    if (stmt->val.expr) {
      val_stmt = flatten_rvalue(stmt->val, &fctx);
    }
    std::vector<Stmt *> indices_stmt(stmt->indices.size(), nullptr);

    for (int i = 0; i < (int)stmt->indices.size(); i++) {
      indices_stmt[i] = flatten_rvalue(stmt->indices[i], &fctx);
    }

    if (stmt->snode->type == SNodeType::dynamic) {
      auto ptr = fctx.push_back<GlobalPtrStmt>(stmt->snode, indices_stmt);
      fctx.push_back<SNodeOpStmt>(stmt->op_type, stmt->snode, ptr, val_stmt);
    } else if (stmt->snode->type == SNodeType::pointer ||
               stmt->snode->type == SNodeType::hash ||
               stmt->snode->type == SNodeType::dense ||
               stmt->snode->type == SNodeType::bitmasked) {
      TI_ASSERT(SNodeOpStmt::activation_related(stmt->op_type));
      auto ptr =
          fctx.push_back<GlobalPtrStmt>(stmt->snode, indices_stmt, true, true);
      fctx.push_back<SNodeOpStmt>(stmt->op_type, stmt->snode, ptr, val_stmt);
    } else {
      TI_ERROR("The {} operation is not supported on {} SNode",
               snode_op_type_name(stmt->op_type),
               snode_type_name(stmt->snode->type));
      TI_NOT_IMPLEMENTED
    }

    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
  }

  void visit(FrontendAssertStmt *stmt) override {
    // Expand the condition and the format arguments.
    Stmt *val_stmt = nullptr;
    auto fctx = make_flatten_ctx();
    if (stmt->cond.expr) {
      val_stmt = flatten_rvalue(stmt->cond, &fctx);
    }

    auto &fargs = stmt->args;  // frontend stmt args
    std::vector<Stmt *> args_stmts(fargs.size());
    for (int i = 0; i < (int)fargs.size(); ++i) {
      args_stmts[i] = flatten_rvalue(fargs[i], &fctx);
    }
    fctx.push_back<AssertStmt>(val_stmt, stmt->text, args_stmts);
    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
  }

  void visit(FrontendExprStmt *stmt) override {
    auto fctx = make_flatten_ctx();
    flatten_rvalue(stmt->val, &fctx);
    stmt->parent->replace_with(stmt, std::move(fctx.stmts));
  }

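  // Lowers external function calls. Exactly one of so_func (shared object),
  // asm_source (inline assembly), or bc_filename (bitcode) must be set;
  // bitcode calls additionally require every argument to be a local variable.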
  void visit(FrontendExternalFuncStmt *stmt) override {
    auto ctx = make_flatten_ctx();
    TI_ASSERT((int)(stmt->so_func != nullptr) +
                  (int)(!stmt->asm_source.empty()) +
                  (int)(!stmt->bc_filename.empty()) ==
              1);
    std::vector<Stmt *> arg_statements, output_statements;
    if (stmt->so_func != nullptr || !stmt->asm_source.empty()) {
      for (auto &s : stmt->args) {
        arg_statements.push_back(flatten_rvalue(s, &ctx));
      }
      for (auto &s : stmt->outputs) {
        output_statements.push_back(flatten_lvalue(s, &ctx));
      }
      ctx.push_back(std::make_unique<ExternalFuncCallStmt>(
          (stmt->so_func != nullptr) ? ExternalFuncCallStmt::SHARED_OBJECT
                                     : ExternalFuncCallStmt::ASSEMBLY,
          stmt->so_func, stmt->asm_source, "", "", arg_statements,
          output_statements));
    } else {
      for (auto &s : stmt->args) {
        TI_ASSERT_INFO(
            s.is<IdExpression>(),
            "external func call via bitcode must pass in local variables.");
        arg_statements.push_back(flatten_lvalue(s, &ctx));
      }
      ctx.push_back(std::make_unique<ExternalFuncCallStmt>(
          ExternalFuncCallStmt::BITCODE, nullptr, "", stmt->bc_filename,
          stmt->bc_funcname, arg_statements, output_statements));
    }
    stmt->parent->replace_with(stmt, std::move(ctx.stmts));
  }

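  // Entry point: first detect range-fors whose bodies contain a break (they
  // must be lowered to while-loops), then lower the whole IR tree.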
  static void run(IRNode *node) {
    LowerAST inst(irpass::analysis::detect_fors_with_break(node));
    node->accept(&inst);
  }
};

namespace irpass {

void lower_ast(IRNode *root) {
  TI_AUTO_PROF;
  LowerAST::run(root);
}

}  // namespace irpass

}  // namespace taichi::lang