1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/transforms/inline.cc
22 * \brief Global function inliner. It contains the following steps:
23 *
24 * - Preprocessing: eligibility checking. Only inline the functions that can
25 * be inlined. We currently only use simple rules to make the decision. No
 * profitability analysis is available for now.
27 *
28 * - Inline: replace the call with a function or the function body depending on
29 * the attribute of the callee function. For example, we return the function
30 * node when it doesn't use default compiler, i.e. llvm. This is because these
31 * functions are packed to be offloaded to external codegen.
32 *
33 * - Postprocessing: remove the replaced functions that have no reference.
34 */
35
36#include <tvm/relay/attrs/annotation.h>
37#include <tvm/relay/expr.h>
38#include <tvm/relay/expr_functor.h>
39#include <tvm/relay/transform.h>
40#include <tvm/runtime/logging.h>
41
42#include <string>
43#include <unordered_set>
44
45#include "../analysis/call_graph.h"
46#include "../op/call/call.h"
47
48using namespace tvm::runtime;
49
50namespace tvm {
51namespace relay {
52
53class Inliner : ExprMutator {
54 public:
55 explicit Inliner(CallGraphEntry* cur_node, CallGraphNode* call_graph)
56 : cur_node_(cur_node), call_graph_(call_graph) {}
57
58 Expr VisitExpr_(const CallNode* call_node) final {
59 // We can work with calls in both pre- and post-lowered form.
60 Call vanilla_call = GetAnyCall(call_node);
61
62 const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>();
63 if (global_var_node) {
64 GlobalVar gv = GetRef<GlobalVar>(global_var_node);
65 auto* cg_node = (*call_graph_)[gv->name_hint];
66 if (CanInline(cg_node)) {
67 Array<Expr> new_args;
68 new_args.reserve(vanilla_call->args.size());
69 for (auto arg : vanilla_call->args) {
70 new_args.push_back(VisitExpr(arg));
71 }
72 // TODO(mbs): Does not handle multiple calls to the same global function.
73 cur_node_->RemoveCallTo(gv);
74 return MakeNewExpr(gv, new_args, GetRef<Call>(call_node));
75 }
76 // else: fallthrough
77 }
78 // else: fallthrough
79
80 // If not calling a global function then nothing to inline.
81 return ExprMutator::VisitExpr_(call_node);
82 }
83
84 Expr VisitExpr_(const GlobalVarNode* gvn) final {
85 GlobalVar gv = GetRef<GlobalVar>(gvn);
86 auto* cg_node = (*call_graph_)[gv->name_hint];
87 if (CanInline(cg_node)) {
88 cur_node_->RemoveCallTo(gv);
89 return MakeNewExpr(gv, {}, GetRef<GlobalVar>(gvn));
90 }
91 return ExprMutator::VisitExpr_(gvn);
92 }
93
94 Function Inline(const Function& func) {
95 return WithFields(func, func->params, VisitExpr(func->body));
96 }
97
98 private:
99 bool CanInline(const CallGraphEntry* cg_node) {
100 // The node must be a leaf node and it cannot be recursive.
101 if (!cg_node->empty() || cg_node->IsRecursive()) return false;
102
103 auto base_func = call_graph_->GetGlobalFunction(cg_node->GetGlobalVar());
104 const auto* function_node = base_func.as<FunctionNode>();
105 if (!function_node) {
106 // Can't inline PrimFuncs!
107 return false;
108 }
109 // The body of a global functions must be defined.
110 if (!function_node->body.defined()) return false;
111
112 // The function must be annotated with the inline attribute.
113 // (Note that partitioned functions and external functions do not have this attribute!)
114 if (!function_node->HasNonzeroAttr(attr::kInline)) return false;
115
116 // The function is not able to be inlined if any callee under the CallGraph
117 // of this function cannot be inlined.
118 for (const auto& it : *cg_node) {
119 if (!CanInline(it.second)) {
120 return false;
121 }
122 }
123
124 return true;
125 }
126
127 // Make a new Relay expression to replace \p expr.
128 Expr MakeNewExpr(const GlobalVar& global, const Array<Expr>& args, const Expr& expr) {
129 ICHECK(expr->IsInstance<CallNode>() || expr->IsInstance<GlobalVarNode>());
130 auto base_func = call_graph_->GetGlobalFunction(global);
131 const auto* fn = base_func.as<FunctionNode>();
132 ICHECK(fn) << "Expected to work on a Relay function.";
133
134 // There is an inconsistency here, the function itself gets shallow-copied but the body is not
135 // shallow-copied.
136 auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs);
137 // Inline the function body to the caller if this function uses default
138 // compiler, i.e. no external codegen is needed.
139 if (!func->GetAttr<String>(attr::kCompiler).defined() && !func->HasNonzeroAttr(attr::kExtern)) {
140 ICHECK_EQ(func->params.size(), args.size())
141 << "Mismatch found in the number of parameters and call args";
142 // Bind the parameters with call args.
143 Map<Var, Expr> bind_map;
144 for (size_t i = 0; i < args.size(); i++) {
145 bind_map.Set(fn->params[i], args[i]);
146 }
147 if (const auto* gvn = expr.as<GlobalVarNode>()) {
148 auto ret_type = gvn->checked_type();
149 // Cannot replace TensorType/TensorTupleType with FuncType. Therefore,
150 // we simply inline the function as a closure instead of directly using
151 // its body when the global var returns FuncType.
152 return ret_type->IsInstance<FuncTypeNode>() ? std::move(func) : func->body;
153 } else {
154 ICHECK(expr->IsInstance<CallNode>());
155 return Bind(func->body, bind_map);
156 }
157 } else if (const auto* call_node = expr.as<CallNode>()) {
158 return Call(func, args, call_node->attrs, call_node->type_args);
159 } else {
160 return std::move(func);
161 }
162 }
163
164 /*!
165 * \brief The current call graph entry that is being handled. Each entry
166 * contains a global function.
167 */
168 CallGraphEntry* cur_node_;
169 /*! \brief The call graph that is used for global function lookup. */
170 const CallGraphNode* call_graph_;
171};
172
173IRModule Inline(const IRModule& module) {
174 CallGraph cg(module);
175 auto topo = cg->TopologicalOrder();
176 // Get the reverse topological order of the global functions.
177 std::reverse(topo.begin(), topo.end());
178 // Cache the functions that are originally entries. These functions will
179 // remain in the module after inlining.
180 std::unordered_set<CallGraphEntry*> original_entry;
181
182 for (auto* it : topo) {
183 if (it->GetRefCount() == 0) original_entry.emplace(it);
184 // Skip the leaf calls and the recursive calls that don't call other
185 // functions.
186 if (it->empty() || (it->IsRecursive() && it->size() == 1)) continue;
187 auto base_func = module->Lookup(it->GetNameHint());
188 if (const auto* fn = base_func.as<FunctionNode>()) {
189 auto func = GetRef<Function>(fn);
190 auto new_func = Inliner(it, cg.operator->()).Inline(func);
191 // TODO(zhiics) Maybe move this to CallGraph, but updating function from
192 // CallGraph arbitarily may lead to incorrect CallGraph.
193 cg->module->Update(it->GetGlobalVar(), new_func);
194 }
195 }
196
197 // Clean up the functions that are inlined and have no reference.
198 for (auto* cgn : topo) {
199 // Skip recursive functions and entry functions even if they are marked as
200 // `inline`.
201 if (cgn->IsRecursive() || original_entry.count(cgn)) continue;
202 auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar());
203 // Skip calls to PrimFuncs since they can't be inlined.
204 if (const auto* fn = base_func.as<FunctionNode>()) {
205 auto func = GetRef<Function>(fn);
206 if (func->HasNonzeroAttr(attr::kInline)) {
207 ICHECK_EQ(cgn->GetRefCount(), 0U)
208 << cgn->GetNameHint() << " is marked as inline but not inlined.";
209 cgn->CleanCallGraphEntries();
210 cg->RemoveGlobalVarFromModule(cgn, /*update_call_graph*/ true);
211 }
212 }
213 }
214
215 return cg->module;
216}
217
218namespace transform {
219
220Pass Inline() {
221 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
222 [=](IRModule m, PassContext pc) { return relay::Inline(m); };
223 return CreateModulePass(pass_func, 1, "InlineGlobals", {});
224}
225
226TVM_REGISTER_GLOBAL("relay._transform.Inline").set_body_typed(Inline);
227
228} // namespace transform
229
230} // namespace relay
231} // namespace tvm
232