1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/backend/vm/lambda_lift.cc |
22 | * \brief Lift all local functions into global functions. |
23 | */ |
24 | |
25 | #include <tvm/node/structural_equal.h> |
26 | #include <tvm/node/structural_hash.h> |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/expr.h> |
29 | #include <tvm/relay/transform.h> |
30 | #include <tvm/runtime/logging.h> |
31 | |
32 | #include <iostream> |
33 | #include <vector> |
34 | |
35 | #include "../../op/annotation/annotation.h" |
36 | #include "../../transforms/device_aware_visitors.h" |
37 | |
38 | using namespace tvm::runtime; |
39 | |
40 | namespace tvm { |
41 | namespace relay { |
42 | namespace vm { |
43 | |
44 | inline std::string GenerateName(const Function& func) { |
45 | size_t hash = tvm::StructuralHash()(func); |
46 | return std::string("lifted_name" ) + std::to_string(hash); |
47 | } |
48 | |
49 | bool IsClosure(const Function& func) { return func->HasNonzeroAttr(attr::kClosure); } |
50 | |
51 | Function MarkClosure(Function func) { |
52 | return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); |
53 | } |
54 | |
/* The goal of this class is to lift any nested (local) functions out into
 * top-level global functions.
 *
 * Each local function is lifted into a global function that takes the set of
 * its free variables as extra parameters; the original use site is rewritten
 * to refer to the newly created global.
 */
class LambdaLifter : public transform::DeviceAwareExprMutator {
 public:
  explicit LambdaLifter(const IRModule& module)
      : transform::DeviceAwareExprMutator(module), module_(module) {}

  /*!
   * \brief Track let-bound lambdas so recursive self-references can be detected.
   *
   * If \p value is a non-primitive function, \p var is kept on the letrec_
   * stack for the duration of the visit of \p value. The call and function
   * visitors below compare variables against letrec_.back() to recognize a
   * recursive reference to the function currently being bound.
   */
  std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
    bool is_lambda = false;
    if (const auto* func_node = value.as<FunctionNode>()) {
      if (!func_node->HasNonzeroAttr(attr::kPrimitive)) {
        is_lambda = true;
        this->letrec_.push_back(var);
      }
    }
    Expr new_value = this->VisitExpr(value);

    if (is_lambda) {
      this->letrec_.pop_back();
    }
    return {var, new_value};
  }

  /*!
   * \brief Rewrite recursive calls to the lambda currently being lifted.
   *
   * A call whose operator is the variable on top of letrec_ is a recursive
   * self-call. It is replaced by the expression recorded in lambda_map_ (the
   * lifted global, or a call to it supplying the captured variables), which
   * was inserted before the function's body was visited.
   */
  Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
    auto call = Downcast<Call>(DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node));
    if (auto var_node = call_node->op.as<VarNode>()) {
      auto var = GetRef<Var>(var_node);
      if (!letrec_.empty() && var == letrec_.back()) {
        auto it = lambda_map_.find(var);
        ICHECK(it != lambda_map_.end());
        // Keep the already-rewritten args; only the operator is swapped for
        // the replacement recorded in lambda_map_.
        return Call(it->second, call->args, call_node->attrs, call_node->type_args);
      }
    }
    return std::move(call);
  }

  /*!
   * \brief Lift a nested function into a fresh global function.
   *
   * Primitive functions and already-global (nesting depth 1) functions are
   * left in place. Otherwise the function's free variables become explicit
   * parameters of a new global; the original occurrence is replaced by the
   * global itself (no captures) or by a call passing the captured variables
   * (closure allocation).
   */
  Expr DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef<Function>(func_node);

    if (func->HasNonzeroAttr(attr::kPrimitive)) {
      // We should not transform primitive functions.
      return std::move(func);
    }

    if (function_nesting() == 1) {
      // We don't need to lift global functions.
      return WithFields(GetRef<Function>(func_node), func_node->params, VisitExpr(func_node->body));
    }

    // Name the global after the function's structural hash so structurally
    // identical lambdas map to the same global (deduplicated below).
    auto name = GenerateName(func);
    auto global = GlobalVar(name);
    auto free_vars = FreeVars(func);
    auto free_type_vars = FreeTypeVars(func, module_);

    // Partition free variables: a free reference to the variable this lambda
    // is let-bound to is recursion, not a capture.
    Array<Var> captured_vars;
    bool recursive = false;
    for (const auto& var : free_vars) {
      if (!letrec_.empty() && var == letrec_.back()) {
        recursive = true;
        continue;
      }
      captured_vars.push_back(var);
    }

    // Freshen all the captured vars.
    Array<Var> typed_captured_vars;
    Map<Var, Expr> rebinding_map;
    for (auto free_var : captured_vars) {
      auto var = Var(free_var->name_hint(), free_var->checked_type());
      var->virtual_device_ = GetVirtualDevice(free_var);
      typed_captured_vars.push_back(var);
      rebinding_map.Set(free_var, var);
    }

    VirtualDevice result_virtual_device = GetVirtualDevice(func_node->body);

    if (recursive) {
      // Record — before visiting the body — what a recursive self-call must be
      // rewritten to, so DeviceAwareVisitExpr_(CallNode*) can find it.
      if (!captured_vars.empty()) {
        Array<Expr> fvs;
        for (auto fv : captured_vars) {
          fvs.push_back(fv);
        }
        lambda_map_.emplace(letrec_.back(), Call(global, fvs));
      } else {
        lambda_map_.emplace(letrec_.back(), global);
      }
    }

    // Recursively lift nested lambdas and rewrite recursive self-calls.
    auto body = Downcast<Function>(DeviceAwareExprMutator::DeviceAwareVisitExpr_(func_node));

    // When performing this optimization there are two cases.
    //
    // The first case in which we have no free variables
    // we can just lift the function into the global
    // environment without needing to allocate a closure.
    //
    //
    // The second case requires that we generate a special
    // function which makes a distinction between allocating
    // a closure, and then the code for the closure.
    //
    // We represent a closure allocation by lifting the
    // closure to a global function which takes its
    // captured arguments and then directly returns
    // the function representing the closure's code.
    //
    // When we generate code later on a call to the "outer"
    // function marked as a closure is used to emit allocation
    // code for the closure's environment.
    //
    // The "inner" function should be used to generate the
    // code for the closure.
    Function lifted_func;
    if (captured_vars.empty() && free_type_vars.empty()) {
      lifted_func = Function(body->params, body->body, body->ret_type, body->type_params,
                             body->attrs, body->span);
      // We also need to copy the virtual device
      lifted_func->virtual_device_ = body->virtual_device();
    } else {
      // When a closure is locally bound in a program, we have its full type information
      // avalible to us.
      //
      // If we lift the closure out of its bound context it may have free variables which
      // do not have type annotations.
      //
      // In this case we first type check the program assigning a type to all sub-expressions.
      //
      // We then change the un-annotated free variables into annotated free variables, use
      // bind to go from unannotated free variables -> annotated free variables and then
      // construct the "closure" function with fully annotated arguments, no longer relying
      // on type inference.
      size_t before_arity = body->params.size();
      VLOG(9) << "Binding " << rebinding_map << " into\n" << PrettyPrint(body->body);
      auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map));
      size_t after_arity = rebound_body->params.size();
      CHECK_EQ(before_arity, after_arity);
      // Outer function: takes the freshened captured vars and returns the
      // rebound inner function (the closure's code).
      lifted_func =
          Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(),
                   free_type_vars, /*attrs=*/{}, func->span);
      lifted_func->virtual_device_ = result_virtual_device;
      lifted_func = MarkClosure(lifted_func);
    }

    ICHECK(lifted_func.defined());

    if (module_->ContainGlobalVar(name)) {
      // A global with this hash-derived name already exists; it must be the
      // same function, otherwise two distinct functions collided on the hash.
      const auto existing_func = module_->Lookup(name);
      ICHECK(tvm::StructuralEqual()(lifted_func, existing_func))
          << "lifted function hash collision" ;
      // If an identical function already exists, use its global var.
      global = module_->GetGlobalVar(name);
    } else {
      // Add the lifted function to the module.
      module_->Add(global, lifted_func);
    }

    if (captured_vars.empty()) {
      return std::move(global);
    } else {
      // If we need to allocate a closure,
      // we pass the variables in its environment here.
      Array<Expr> fvs;
      for (auto fv : captured_vars) {
        fvs.push_back(fv);
      }
      return Call(global, fvs);
    }
  }

  /*!
   * \brief Run lambda lifting over every global function in the module.
   * \return The module (mutated in place) with all local functions lifted.
   */
  IRModule Lift() {
    // There is an ordering bug here.
    // NOTE(review): iteration is over a snapshot of module_->functions while
    // Mutate() adds lifted globals to module_; the order in which globals are
    // processed is presumably unspecified — confirm whether results (e.g.
    // which duplicate wins the hash-name) can depend on it.
    auto glob_funcs = module_->functions;
    for (auto pair : glob_funcs) {
      if (auto* n = pair.second.as<FunctionNode>()) {
        // Functions carrying a kCompiler attribute are skipped entirely.
        if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
        auto func = GetRef<Function>(n);
        module_->Add(pair.first, Downcast<Function>(Mutate(func)), /*update=*/true);
      }
    }
    return module_;
  }

 private:
  // Maps a letrec-bound variable to the expression (the lifted global, or a
  // call to it passing captured vars) that replaces recursive references.
  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> lambda_map_;
  // Stack of variables whose let-bound lambda is currently being visited;
  // back() is the innermost one.
  std::vector<Var> letrec_;
  // Module being transformed; lifted globals are added to it in place.
  IRModule module_;
};
246 | |
247 | } // namespace vm |
248 | |
249 | namespace transform { |
250 | |
251 | Pass LambdaLift() { |
252 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = |
253 | [=](IRModule m, PassContext pc) { return relay::vm::LambdaLifter(m).Lift(); }; |
254 | return CreateModulePass(pass_func, 1, "LambdaLift" , {}); |
255 | } |
256 | |
257 | TVM_REGISTER_GLOBAL("relay._transform.LambdaLift" ).set_body_typed(LambdaLift); |
258 | |
259 | } // namespace transform |
260 | |
261 | } // namespace relay |
262 | } // namespace tvm |
263 | |