1#include "taichi/transforms/inlining.h"
2#include "taichi/ir/analysis.h"
3#include "taichi/ir/ir.h"
4#include "taichi/ir/statements.h"
5#include "taichi/ir/transforms.h"
6#include "taichi/ir/visitors.h"
7#include "taichi/program/program.h"
8
9namespace taichi::lang {
10
11// Inline all functions.
12class Inliner : public BasicStmtVisitor {
13 public:
14 using BasicStmtVisitor::visit;
15
16 explicit Inliner() {
17 }
18
19 void visit(FuncCallStmt *stmt) override {
20 auto *func = stmt->func;
21 TI_ASSERT(func);
22 TI_ASSERT(func->parameter_list.size() == stmt->args.size());
23 TI_ASSERT(func->ir->is<Block>());
24 TI_ASSERT(func->rets.size() <= 1);
25 auto inlined_ir = irpass::analysis::clone(func->ir.get());
26 if (!func->parameter_list.empty()) {
27 irpass::replace_statements(
28 inlined_ir.get(),
29 /*filter=*/[&](Stmt *s) { return s->is<ArgLoadStmt>(); },
30 /*finder=*/
31 [&](Stmt *s) { return stmt->args[s->as<ArgLoadStmt>()->arg_id]; });
32 }
33 if (func->rets.empty()) {
34 modifier_.replace_with(
35 stmt, VecStatement(std::move(inlined_ir->as<Block>()->statements)));
36 } else {
37 if (irpass::analysis::gather_statements(inlined_ir.get(), [&](Stmt *s) {
38 return s->is<ReturnStmt>();
39 }).size() > 1) {
40 TI_WARN(
41 "Multiple returns in function \"{}\" may not be handled "
42 "properly.\n{}",
43 func->get_name(), stmt->tb);
44 }
45 // Use a local variable to store the return value
46 auto *return_address = inlined_ir->as<Block>()->insert(
47 Stmt::make<AllocaStmt>(func->rets[0].dt), /*location=*/0);
48 irpass::replace_and_insert_statements(
49 inlined_ir.get(),
50 /*filter=*/[&](Stmt *s) { return s->is<ReturnStmt>(); },
51 /*generator=*/
52 [&](Stmt *s) {
53 TI_ASSERT(s->as<ReturnStmt>()->values.size() == 1);
54 return Stmt::make<LocalStoreStmt>(return_address,
55 s->as<ReturnStmt>()->values[0]);
56 });
57 modifier_.insert_before(stmt,
58 std::move(inlined_ir->as<Block>()->statements));
59 // Load the return value here
60 modifier_.replace_with(stmt, Stmt::make<LocalLoadStmt>(return_address));
61 }
62 }
63
64 static bool run(IRNode *node) {
65 Inliner inliner;
66 bool modified = false;
67 while (true) {
68 node->accept(&inliner);
69 if (inliner.modifier_.modify_ir())
70 modified = true;
71 else
72 break;
73 }
74 return modified;
75 }
76
77 private:
78 DelayedIRModifier modifier_;
79};
80
81const PassID InliningPass::id = "InliningPass";
82
83namespace irpass {
84
85bool inlining(IRNode *root,
86 const CompileConfig &config,
87 const InliningPass::Args &args) {
88 TI_AUTO_PROF;
89 return Inliner::run(root);
90}
91
92} // namespace irpass
93
94} // namespace taichi::lang
95