1 | #include "taichi/transforms/inlining.h" |
2 | #include "taichi/ir/analysis.h" |
3 | #include "taichi/ir/ir.h" |
4 | #include "taichi/ir/statements.h" |
5 | #include "taichi/ir/transforms.h" |
6 | #include "taichi/ir/visitors.h" |
7 | #include "taichi/program/program.h" |
8 | |
9 | namespace taichi::lang { |
10 | |
11 | // Inline all functions. |
12 | class Inliner : public BasicStmtVisitor { |
13 | public: |
14 | using BasicStmtVisitor::visit; |
15 | |
16 | explicit Inliner() { |
17 | } |
18 | |
19 | void visit(FuncCallStmt *stmt) override { |
20 | auto *func = stmt->func; |
21 | TI_ASSERT(func); |
22 | TI_ASSERT(func->parameter_list.size() == stmt->args.size()); |
23 | TI_ASSERT(func->ir->is<Block>()); |
24 | TI_ASSERT(func->rets.size() <= 1); |
25 | auto inlined_ir = irpass::analysis::clone(func->ir.get()); |
26 | if (!func->parameter_list.empty()) { |
27 | irpass::replace_statements( |
28 | inlined_ir.get(), |
29 | /*filter=*/[&](Stmt *s) { return s->is<ArgLoadStmt>(); }, |
30 | /*finder=*/ |
31 | [&](Stmt *s) { return stmt->args[s->as<ArgLoadStmt>()->arg_id]; }); |
32 | } |
33 | if (func->rets.empty()) { |
34 | modifier_.replace_with( |
35 | stmt, VecStatement(std::move(inlined_ir->as<Block>()->statements))); |
36 | } else { |
37 | if (irpass::analysis::gather_statements(inlined_ir.get(), [&](Stmt *s) { |
38 | return s->is<ReturnStmt>(); |
39 | }).size() > 1) { |
40 | TI_WARN( |
41 | "Multiple returns in function \"{}\" may not be handled " |
42 | "properly.\n{}" , |
43 | func->get_name(), stmt->tb); |
44 | } |
45 | // Use a local variable to store the return value |
46 | auto *return_address = inlined_ir->as<Block>()->insert( |
47 | Stmt::make<AllocaStmt>(func->rets[0].dt), /*location=*/0); |
48 | irpass::replace_and_insert_statements( |
49 | inlined_ir.get(), |
50 | /*filter=*/[&](Stmt *s) { return s->is<ReturnStmt>(); }, |
51 | /*generator=*/ |
52 | [&](Stmt *s) { |
53 | TI_ASSERT(s->as<ReturnStmt>()->values.size() == 1); |
54 | return Stmt::make<LocalStoreStmt>(return_address, |
55 | s->as<ReturnStmt>()->values[0]); |
56 | }); |
57 | modifier_.insert_before(stmt, |
58 | std::move(inlined_ir->as<Block>()->statements)); |
59 | // Load the return value here |
60 | modifier_.replace_with(stmt, Stmt::make<LocalLoadStmt>(return_address)); |
61 | } |
62 | } |
63 | |
64 | static bool run(IRNode *node) { |
65 | Inliner inliner; |
66 | bool modified = false; |
67 | while (true) { |
68 | node->accept(&inliner); |
69 | if (inliner.modifier_.modify_ir()) |
70 | modified = true; |
71 | else |
72 | break; |
73 | } |
74 | return modified; |
75 | } |
76 | |
77 | private: |
78 | DelayedIRModifier modifier_; |
79 | }; |
80 | |
81 | const PassID InliningPass::id = "InliningPass" ; |
82 | |
83 | namespace irpass { |
84 | |
85 | bool inlining(IRNode *root, |
86 | const CompileConfig &config, |
87 | const InliningPass::Args &args) { |
88 | TI_AUTO_PROF; |
89 | return Inliner::run(root); |
90 | } |
91 | |
92 | } // namespace irpass |
93 | |
94 | } // namespace taichi::lang |
95 | |