1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/backend/vm/lambda_lift.cc
22 * \brief Lift all local functions into global functions.
23 */
24
25#include <tvm/node/structural_equal.h>
26#include <tvm/node/structural_hash.h>
27#include <tvm/relay/analysis.h>
28#include <tvm/relay/expr.h>
29#include <tvm/relay/transform.h>
30#include <tvm/runtime/logging.h>
31
32#include <iostream>
33#include <vector>
34
35#include "../../op/annotation/annotation.h"
36#include "../../transforms/device_aware_visitors.h"
37
38using namespace tvm::runtime;
39
40namespace tvm {
41namespace relay {
42namespace vm {
43
44inline std::string GenerateName(const Function& func) {
45 size_t hash = tvm::StructuralHash()(func);
46 return std::string("lifted_name") + std::to_string(hash);
47}
48
49bool IsClosure(const Function& func) { return func->HasNonzeroAttr(attr::kClosure); }
50
51Function MarkClosure(Function func) {
52 return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1));
53}
54
55/* The goal of this class is to lift out any nested functions into top-level
56 * functions.
57 *
58 * We will lift a function out into a global which takes the set of the free
59 * vars and then return the new created function.
60 */
61class LambdaLifter : public transform::DeviceAwareExprMutator {
62 public:
63 explicit LambdaLifter(const IRModule& module)
64 : transform::DeviceAwareExprMutator(module), module_(module) {}
65
66 std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
67 bool is_lambda = false;
68 if (const auto* func_node = value.as<FunctionNode>()) {
69 if (!func_node->HasNonzeroAttr(attr::kPrimitive)) {
70 is_lambda = true;
71 this->letrec_.push_back(var);
72 }
73 }
74 Expr new_value = this->VisitExpr(value);
75
76 if (is_lambda) {
77 this->letrec_.pop_back();
78 }
79 return {var, new_value};
80 }
81
82 Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
83 auto call = Downcast<Call>(DeviceAwareExprMutator::DeviceAwareVisitExpr_(call_node));
84 if (auto var_node = call_node->op.as<VarNode>()) {
85 auto var = GetRef<Var>(var_node);
86 if (!letrec_.empty() && var == letrec_.back()) {
87 auto it = lambda_map_.find(var);
88 ICHECK(it != lambda_map_.end());
89 return Call(it->second, call->args, call_node->attrs, call_node->type_args);
90 }
91 }
92 return std::move(call);
93 }
94
95 Expr DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
96 auto func = GetRef<Function>(func_node);
97
98 if (func->HasNonzeroAttr(attr::kPrimitive)) {
99 // We should not transform primitive functions.
100 return std::move(func);
101 }
102
103 if (function_nesting() == 1) {
104 // We don't need to lift global functions.
105 return WithFields(GetRef<Function>(func_node), func_node->params, VisitExpr(func_node->body));
106 }
107
108 auto name = GenerateName(func);
109 auto global = GlobalVar(name);
110 auto free_vars = FreeVars(func);
111 auto free_type_vars = FreeTypeVars(func, module_);
112
113 Array<Var> captured_vars;
114 bool recursive = false;
115 for (const auto& var : free_vars) {
116 if (!letrec_.empty() && var == letrec_.back()) {
117 recursive = true;
118 continue;
119 }
120 captured_vars.push_back(var);
121 }
122
123 // Freshen all the captured vars.
124 Array<Var> typed_captured_vars;
125 Map<Var, Expr> rebinding_map;
126 for (auto free_var : captured_vars) {
127 auto var = Var(free_var->name_hint(), free_var->checked_type());
128 var->virtual_device_ = GetVirtualDevice(free_var);
129 typed_captured_vars.push_back(var);
130 rebinding_map.Set(free_var, var);
131 }
132
133 VirtualDevice result_virtual_device = GetVirtualDevice(func_node->body);
134
135 if (recursive) {
136 if (!captured_vars.empty()) {
137 Array<Expr> fvs;
138 for (auto fv : captured_vars) {
139 fvs.push_back(fv);
140 }
141 lambda_map_.emplace(letrec_.back(), Call(global, fvs));
142 } else {
143 lambda_map_.emplace(letrec_.back(), global);
144 }
145 }
146
147 auto body = Downcast<Function>(DeviceAwareExprMutator::DeviceAwareVisitExpr_(func_node));
148
149 // When performing this optimization there are two cases.
150 //
151 // The first case in which we have no free variables
152 // we can just lift the function into the global
153 // environment without needing to allocate a closure.
154 //
155 //
156 // The second case requires that we generate a special
157 // function which makes a distinction between allocating
158 // a closure, and then the code for the closure.
159 //
160 // We represent a closure allocation by lifting the
161 // closure to a global function which takes its
162 // captured arguments and then directly returns
163 // the function representing the closure's code.
164 //
165 // When we generate code later on a call to the "outer"
166 // function marked as a closure is used to emit allocation
167 // code for the closure's environment.
168 //
169 // The "inner" function should be used to generate the
170 // code for the closure.
171 Function lifted_func;
172 if (captured_vars.empty() && free_type_vars.empty()) {
173 lifted_func = Function(body->params, body->body, body->ret_type, body->type_params,
174 body->attrs, body->span);
175 // We also need to copy the virtual device
176 lifted_func->virtual_device_ = body->virtual_device();
177 } else {
178 // When a closure is locally bound in a program, we have its full type information
179 // avalible to us.
180 //
181 // If we lift the closure out of its bound context it may have free variables which
182 // do not have type annotations.
183 //
184 // In this case we first type check the program assigning a type to all sub-expressions.
185 //
186 // We then change the un-annotated free variables into annotated free variables, use
187 // bind to go from unannotated free variables -> annotated free variables and then
188 // construct the "closure" function with fully annotated arguments, no longer relying
189 // on type inference.
190 size_t before_arity = body->params.size();
191 VLOG(9) << "Binding " << rebinding_map << " into\n" << PrettyPrint(body->body);
192 auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map));
193 size_t after_arity = rebound_body->params.size();
194 CHECK_EQ(before_arity, after_arity);
195 lifted_func =
196 Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(),
197 free_type_vars, /*attrs=*/{}, func->span);
198 lifted_func->virtual_device_ = result_virtual_device;
199 lifted_func = MarkClosure(lifted_func);
200 }
201
202 ICHECK(lifted_func.defined());
203
204 if (module_->ContainGlobalVar(name)) {
205 const auto existing_func = module_->Lookup(name);
206 ICHECK(tvm::StructuralEqual()(lifted_func, existing_func))
207 << "lifted function hash collision";
208 // If an identical function already exists, use its global var.
209 global = module_->GetGlobalVar(name);
210 } else {
211 // Add the lifted function to the module.
212 module_->Add(global, lifted_func);
213 }
214
215 if (captured_vars.empty()) {
216 return std::move(global);
217 } else {
218 // If we need to allocate a closure,
219 // we pass the variables in its environment here.
220 Array<Expr> fvs;
221 for (auto fv : captured_vars) {
222 fvs.push_back(fv);
223 }
224 return Call(global, fvs);
225 }
226 }
227
228 IRModule Lift() {
229 // There is an ordering bug here.
230 auto glob_funcs = module_->functions;
231 for (auto pair : glob_funcs) {
232 if (auto* n = pair.second.as<FunctionNode>()) {
233 if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
234 auto func = GetRef<Function>(n);
235 module_->Add(pair.first, Downcast<Function>(Mutate(func)), /*update=*/true);
236 }
237 }
238 return module_;
239 }
240
241 private:
242 std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> lambda_map_;
243 std::vector<Var> letrec_;
244 IRModule module_;
245};
246
247} // namespace vm
248
249namespace transform {
250
251Pass LambdaLift() {
252 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
253 [=](IRModule m, PassContext pc) { return relay::vm::LambdaLifter(m).Lift(); };
254 return CreateModulePass(pass_func, 1, "LambdaLift", {});
255}
256
// Register the LambdaLift pass factory in TVM's global function registry under the name
// "relay._transform.LambdaLift".
TVM_REGISTER_GLOBAL("relay._transform.LambdaLift").set_body_typed(LambdaLift);
258
259} // namespace transform
260
261} // namespace relay
262} // namespace tvm
263