1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/inline.cc |
22 | * \brief Global function inliner. It contains the following steps: |
23 | * |
 * - Preprocessing: eligibility checking. Only functions that can be inlined
 *   are inlined. We currently use only simple rules to make the decision; no
 *   profitability analysis is available yet.
27 | * |
 * - Inline: replace the call with either the callee function or its body,
 *   depending on the callee's attributes. For example, the function node itself
 *   is returned when the callee does not use the default compiler (e.g. LLVM),
 *   because such functions are packaged to be offloaded to an external codegen.
32 | * |
 * - Postprocessing: remove the inlined functions that are no longer referenced.
34 | */ |
35 | |
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>

#include <algorithm>
#include <string>
#include <unordered_set>

#include "../analysis/call_graph.h"
#include "../op/call/call.h"
47 | |
48 | using namespace tvm::runtime; |
49 | |
50 | namespace tvm { |
51 | namespace relay { |
52 | |
53 | class Inliner : ExprMutator { |
54 | public: |
55 | explicit Inliner(CallGraphEntry* cur_node, CallGraphNode* call_graph) |
56 | : cur_node_(cur_node), call_graph_(call_graph) {} |
57 | |
58 | Expr VisitExpr_(const CallNode* call_node) final { |
59 | // We can work with calls in both pre- and post-lowered form. |
60 | Call vanilla_call = GetAnyCall(call_node); |
61 | |
62 | const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>(); |
63 | if (global_var_node) { |
64 | GlobalVar gv = GetRef<GlobalVar>(global_var_node); |
65 | auto* cg_node = (*call_graph_)[gv->name_hint]; |
66 | if (CanInline(cg_node)) { |
67 | Array<Expr> new_args; |
68 | new_args.reserve(vanilla_call->args.size()); |
69 | for (auto arg : vanilla_call->args) { |
70 | new_args.push_back(VisitExpr(arg)); |
71 | } |
72 | // TODO(mbs): Does not handle multiple calls to the same global function. |
73 | cur_node_->RemoveCallTo(gv); |
74 | return MakeNewExpr(gv, new_args, GetRef<Call>(call_node)); |
75 | } |
76 | // else: fallthrough |
77 | } |
78 | // else: fallthrough |
79 | |
80 | // If not calling a global function then nothing to inline. |
81 | return ExprMutator::VisitExpr_(call_node); |
82 | } |
83 | |
84 | Expr VisitExpr_(const GlobalVarNode* gvn) final { |
85 | GlobalVar gv = GetRef<GlobalVar>(gvn); |
86 | auto* cg_node = (*call_graph_)[gv->name_hint]; |
87 | if (CanInline(cg_node)) { |
88 | cur_node_->RemoveCallTo(gv); |
89 | return MakeNewExpr(gv, {}, GetRef<GlobalVar>(gvn)); |
90 | } |
91 | return ExprMutator::VisitExpr_(gvn); |
92 | } |
93 | |
94 | Function Inline(const Function& func) { |
95 | return WithFields(func, func->params, VisitExpr(func->body)); |
96 | } |
97 | |
98 | private: |
99 | bool CanInline(const CallGraphEntry* cg_node) { |
100 | // The node must be a leaf node and it cannot be recursive. |
101 | if (!cg_node->empty() || cg_node->IsRecursive()) return false; |
102 | |
103 | auto base_func = call_graph_->GetGlobalFunction(cg_node->GetGlobalVar()); |
104 | const auto* function_node = base_func.as<FunctionNode>(); |
105 | if (!function_node) { |
106 | // Can't inline PrimFuncs! |
107 | return false; |
108 | } |
109 | // The body of a global functions must be defined. |
110 | if (!function_node->body.defined()) return false; |
111 | |
112 | // The function must be annotated with the inline attribute. |
113 | // (Note that partitioned functions and external functions do not have this attribute!) |
114 | if (!function_node->HasNonzeroAttr(attr::kInline)) return false; |
115 | |
116 | // The function is not able to be inlined if any callee under the CallGraph |
117 | // of this function cannot be inlined. |
118 | for (const auto& it : *cg_node) { |
119 | if (!CanInline(it.second)) { |
120 | return false; |
121 | } |
122 | } |
123 | |
124 | return true; |
125 | } |
126 | |
127 | // Make a new Relay expression to replace \p expr. |
128 | Expr MakeNewExpr(const GlobalVar& global, const Array<Expr>& args, const Expr& expr) { |
129 | ICHECK(expr->IsInstance<CallNode>() || expr->IsInstance<GlobalVarNode>()); |
130 | auto base_func = call_graph_->GetGlobalFunction(global); |
131 | const auto* fn = base_func.as<FunctionNode>(); |
132 | ICHECK(fn) << "Expected to work on a Relay function." ; |
133 | |
134 | // There is an inconsistency here, the function itself gets shallow-copied but the body is not |
135 | // shallow-copied. |
136 | auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); |
137 | // Inline the function body to the caller if this function uses default |
138 | // compiler, i.e. no external codegen is needed. |
139 | if (!func->GetAttr<String>(attr::kCompiler).defined() && !func->HasNonzeroAttr(attr::kExtern)) { |
140 | ICHECK_EQ(func->params.size(), args.size()) |
141 | << "Mismatch found in the number of parameters and call args" ; |
142 | // Bind the parameters with call args. |
143 | Map<Var, Expr> bind_map; |
144 | for (size_t i = 0; i < args.size(); i++) { |
145 | bind_map.Set(fn->params[i], args[i]); |
146 | } |
147 | if (const auto* gvn = expr.as<GlobalVarNode>()) { |
148 | auto ret_type = gvn->checked_type(); |
149 | // Cannot replace TensorType/TensorTupleType with FuncType. Therefore, |
150 | // we simply inline the function as a closure instead of directly using |
151 | // its body when the global var returns FuncType. |
152 | return ret_type->IsInstance<FuncTypeNode>() ? std::move(func) : func->body; |
153 | } else { |
154 | ICHECK(expr->IsInstance<CallNode>()); |
155 | return Bind(func->body, bind_map); |
156 | } |
157 | } else if (const auto* call_node = expr.as<CallNode>()) { |
158 | return Call(func, args, call_node->attrs, call_node->type_args); |
159 | } else { |
160 | return std::move(func); |
161 | } |
162 | } |
163 | |
164 | /*! |
165 | * \brief The current call graph entry that is being handled. Each entry |
166 | * contains a global function. |
167 | */ |
168 | CallGraphEntry* cur_node_; |
169 | /*! \brief The call graph that is used for global function lookup. */ |
170 | const CallGraphNode* call_graph_; |
171 | }; |
172 | |
173 | IRModule Inline(const IRModule& module) { |
174 | CallGraph cg(module); |
175 | auto topo = cg->TopologicalOrder(); |
176 | // Get the reverse topological order of the global functions. |
177 | std::reverse(topo.begin(), topo.end()); |
178 | // Cache the functions that are originally entries. These functions will |
179 | // remain in the module after inlining. |
180 | std::unordered_set<CallGraphEntry*> original_entry; |
181 | |
182 | for (auto* it : topo) { |
183 | if (it->GetRefCount() == 0) original_entry.emplace(it); |
184 | // Skip the leaf calls and the recursive calls that don't call other |
185 | // functions. |
186 | if (it->empty() || (it->IsRecursive() && it->size() == 1)) continue; |
187 | auto base_func = module->Lookup(it->GetNameHint()); |
188 | if (const auto* fn = base_func.as<FunctionNode>()) { |
189 | auto func = GetRef<Function>(fn); |
190 | auto new_func = Inliner(it, cg.operator->()).Inline(func); |
191 | // TODO(zhiics) Maybe move this to CallGraph, but updating function from |
192 | // CallGraph arbitarily may lead to incorrect CallGraph. |
193 | cg->module->Update(it->GetGlobalVar(), new_func); |
194 | } |
195 | } |
196 | |
197 | // Clean up the functions that are inlined and have no reference. |
198 | for (auto* cgn : topo) { |
199 | // Skip recursive functions and entry functions even if they are marked as |
200 | // `inline`. |
201 | if (cgn->IsRecursive() || original_entry.count(cgn)) continue; |
202 | auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar()); |
203 | // Skip calls to PrimFuncs since they can't be inlined. |
204 | if (const auto* fn = base_func.as<FunctionNode>()) { |
205 | auto func = GetRef<Function>(fn); |
206 | if (func->HasNonzeroAttr(attr::kInline)) { |
207 | ICHECK_EQ(cgn->GetRefCount(), 0U) |
208 | << cgn->GetNameHint() << " is marked as inline but not inlined." ; |
209 | cgn->CleanCallGraphEntries(); |
210 | cg->RemoveGlobalVarFromModule(cgn, /*update_call_graph*/ true); |
211 | } |
212 | } |
213 | } |
214 | |
215 | return cg->module; |
216 | } |
217 | |
218 | namespace transform { |
219 | |
220 | Pass Inline() { |
221 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = |
222 | [=](IRModule m, PassContext pc) { return relay::Inline(m); }; |
223 | return CreateModulePass(pass_func, 1, "InlineGlobals" , {}); |
224 | } |
225 | |
226 | TVM_REGISTER_GLOBAL("relay._transform.Inline" ).set_body_typed(Inline); |
227 | |
228 | } // namespace transform |
229 | |
230 | } // namespace relay |
231 | } // namespace tvm |
232 | |