1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file defunctionalization.cc |
23 | * |
24 | * \brief Defunctionalization for Relay IR |
25 | * |
 * This pass transforms a higher-order program into a first-order program with defunctionalization.
 * This means that all higher-order functions (i.e., functions that take function arguments or
 * return functions) are transformed into semantically equivalent first-order ones.
29 | * |
30 | * This pass implements a basic typed defunctionalization method. |
31 | * All higher order functions are cloned and specialized (so that there are no type params). |
32 | * Function type arguments are encoded as datatypes and a helper `apply` function is used |
33 | * to "call" them. |
34 | * |
 * For example, take the following higher-order program:
 *   fun map F y = case y of
 *       Nil => Nil
 *     | Cons(x, xs) => Cons(F x, map F xs)
 *   fun addone l = map (\x -> x + 1) l
 *
 * where `addone` is our program.
 * When we call the `map` function, we see that it is a higher-order function,
 * but we can clone the `map` function and specialize it with the type arguments of the call.
 * In addition, our function argument `(\x -> x + 1)` will be encoded as a datatype constructor,
 * which we will call `incr`, and all calls to `F` in our specialized map function will use the
 * helper `apply` function.
 *
 * After defunctionalization, we get:
 *   fun apply encoding arg = case encoding of
 *       "incr" => incr arg
 *   fun map' F y = case y of
 *       Nil => Nil
 *     | Cons(x, xs) => Cons(apply F x, map' F xs)
 *   fun addone l = map' "incr" l
55 | * |
 * Currently, defunctionalization makes the following assumptions:
 * - functions cannot return function values
 * - function arguments take one of two forms: an identifier or a lambda abstraction
 * - no functions are stored in datatypes
 * - functions are not let-bound
61 | */ |
62 | |
63 | #include <tvm/ir/type_functor.h> |
64 | #include <tvm/relay/analysis.h> |
65 | #include <tvm/relay/expr_functor.h> |
66 | #include <tvm/relay/feature.h> |
67 | #include <tvm/relay/transform.h> |
68 | #include <tvm/te/operation.h> |
69 | |
70 | #include "../analysis/type_solver.h" |
71 | #include "../transforms/pass_utils.h" |
72 | namespace tvm { |
73 | namespace relay { |
74 | |
75 | // determine if type contains a FuncType |
76 | bool HasFuncType(const Type& t) { |
77 | struct FuncTypeVisitor : TypeVisitor { |
78 | bool has_func_type; |
79 | FuncTypeVisitor() : has_func_type(false) {} |
80 | |
81 | void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; } |
82 | }; |
83 | |
84 | auto visitor = FuncTypeVisitor(); |
85 | visitor.VisitType(t); |
86 | return visitor.has_func_type; |
87 | } |
88 | // determine if FuncType is a higher order type |
89 | bool IsHigherOrderFunc(const FuncType& t) { |
90 | bool higher_order = false; |
91 | for (auto arg : t->arg_types) { |
92 | higher_order |= HasFuncType(arg); |
93 | } |
94 | return higher_order |= HasFuncType(t->ret_type); |
95 | } |
96 | |
97 | /*! |
98 | * \brief mutator for driving the Defunctionalization transformation |
99 | */ |
100 | class DefuncMutator : public ExprMutator { |
101 | public: |
102 | explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {} |
103 | |
104 | Expr VisitExpr_(const CallNode* call) { |
105 | if (auto op = call->op.as<GlobalVarNode>()) { |
106 | ICHECK_EQ(call->type_args.size(), op->checked_type().as<FuncTypeNode>()->type_params.size()) |
107 | << "all type args must be explicit" ; |
108 | |
109 | auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args); |
110 | ICHECK_EQ(FreeTypeVars(op_type, mod).size(), 0) << "free type vars in instantiated" ; |
111 | ICHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported" ; |
112 | |
113 | if (!IsHigherOrderFunc(op_type)) { |
114 | // not higher order function |
115 | return ExprMutator::VisitExpr_(call); |
116 | } |
117 | |
118 | // first we encode function arguments |
119 | Array<Expr> args; |
120 | for (size_t i = 0; i < call->args.size(); i++) { |
121 | auto arg = call->args[i]; |
122 | auto type = op_type->arg_types[i]; |
123 | if (!HasFuncType(type)) { |
124 | args.push_back(arg); |
125 | } else { |
126 | args.push_back(EncodeArg(arg, type)); |
127 | } |
128 | } |
129 | auto name = op->name_hint + TypeToString(op_type); |
130 | auto gv = GlobalVar(name); |
131 | if (specialized_gv_map.count(name)) { |
132 | gv = specialized_gv_map[name]; |
133 | } else { |
134 | specialized_gv_map[name] = gv; |
135 | // clone and specialize with specific type |
136 | auto clone = Downcast<Function>(DeDup(mod->Lookup(GetRef<GlobalVar>(op)))); |
137 | auto specialized_function = Specialize(clone, call->type_args); |
138 | // change var types and change all applications to use `apply` method |
139 | auto f = Downcast<Function>(FirstifyVars(specialized_function)); |
140 | mod->Add(gv, f); |
141 | } |
142 | return Call(gv, args); |
143 | } else if (auto op = call->op.as<FunctionNode>()) { |
144 | // reduction by applying vars |
145 | std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_binding_map; |
146 | for (size_t i = 0; i < op->params.size(); i++) { |
147 | var_binding_map[op->params[i]] = call->args[i]; |
148 | } |
149 | auto e = Bind(op->body, var_binding_map); |
150 | return this->VisitExpr(e); |
151 | } else if (auto op = call->op.as<VarNode>()) { |
152 | // var node will be encoded as datatype |
153 | // so we need to use the `apply` helper method |
154 | auto var_original_type = GetUnencodedType(op->type_annotation).as<FuncTypeNode>(); |
155 | ICHECK(var_original_type) << "var original type not saved in var_save_type map" ; |
156 | auto op_type = InstFuncType(var_original_type, call->type_args); |
157 | |
158 | Array<Expr> args = {GetRef<Var>(op)}; |
159 | for (auto arg : call->args) { |
160 | args.push_back(this->VisitExpr(arg)); |
161 | } |
162 | |
163 | return Call(GetApplyFunction(op_type), args); |
164 | } |
165 | return ExprMutator::VisitExpr_(call); |
166 | } |
167 | |
168 | private: |
169 | // module |
170 | IRModule mod; |
171 | // gv + str(type) to specialized clone gv |
172 | std::unordered_map<std::string, GlobalVar> specialized_gv_map; |
173 | // str(func_type) to ADT |
174 | std::unordered_map<std::string, GlobalTypeVar> func_encoding; |
175 | // str(func_tyoe) to apply gv |
176 | std::unordered_map<std::string, GlobalVar> apply_map; |
177 | // encoded ADT handle to FuncType |
178 | std::unordered_map<GlobalTypeVar, Type, ObjectHash, StructuralEqual> original_func_type_map; |
179 | // gv to (str(func_type) to constructor encoding) |
180 | std::unordered_map<GlobalVar, std::unordered_map<std::string, Constructor>, ObjectHash, |
181 | ObjectEqual> |
182 | gv_datatype_map; |
183 | // use monotonically increasing integer to represent new constructor_name |
184 | uint64_t constructor_counter; |
185 | |
186 | /*! |
187 | * \brief add a constructor to the GlobalTypeVar, creating a new TypeDef if GlobalTypeVar does not |
188 | * exist |
189 | */ |
190 | void AddConstructor(GlobalTypeVar gtv, Constructor c) { |
191 | if (!mod->ContainGlobalTypeVar(gtv->name_hint)) { |
192 | mod->AddTypeDef(gtv, TypeData(gtv, {}, {c})); |
193 | } else { |
194 | auto typedata = mod->LookupTypeDef(gtv); |
195 | auto constructors = typedata->constructors; |
196 | constructors.push_back(c); |
197 | mod->UpdateTypeDef(gtv, TypeData(typedata->header, typedata->type_vars, constructors)); |
198 | } |
199 | } |
200 | /*! |
201 | * \brief add a case to the apply function, creating the function if it does not exist |
202 | * |
203 | * \param apply_gv GlobalVar of the apply function |
204 | * \param ft is the type functions the apply function handles |
205 | * \param c constructor to add a case for |
206 | * \param expr calls this expr with the args to the apply_gv |
207 | * \param patterns PatterVars to match with the constructor, used for handling free vars in |
208 | * functions |
209 | */ |
210 | void AddApplyCase(GlobalVar apply_gv, FuncType ft, Constructor c, const Expr& expr, |
211 | const Array<Pattern> patterns) { |
212 | ICHECK(c->inputs.size() == patterns.size()) |
213 | << "constructor function and pattern vars have different sizes" ; |
214 | if (!mod->ContainGlobalVar(apply_gv->name_hint)) { |
215 | auto x = Var("x" , TypeCall(c->belong_to, {})); |
216 | auto vars = Array<Var>({x}); |
217 | auto args = Array<Expr>(); |
218 | for (auto t : ft->arg_types) { |
219 | auto y = Var("y" , t); |
220 | vars.push_back(y); |
221 | args.push_back(y); |
222 | } |
223 | |
224 | auto clauses = Array<Clause>({Clause(PatternConstructor(c, patterns), Call(expr, args))}); |
225 | auto body = Match(x, clauses); |
226 | auto f = Function(vars, body, ft->ret_type, {}); |
227 | |
228 | mod->Add(apply_gv, f); |
229 | } else { |
230 | auto f = Downcast<Function>(mod->Lookup(apply_gv)); |
231 | auto body = f->body.as<MatchNode>(); |
232 | ICHECK(body) << "internal invariant broken; apply function body should be a match node" ; |
233 | |
234 | auto clauses = body->clauses; |
235 | auto x = f->params[0]; |
236 | auto args = Array<Expr>(); |
237 | for (size_t i = 1; i < f->params.size(); i++) { |
238 | args.push_back(f->params[i]); |
239 | } |
240 | clauses.push_back(Clause(PatternConstructor(c, patterns), Call(expr, args))); |
241 | |
242 | mod->Add(apply_gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true); |
243 | } |
244 | } |
245 | |
246 | Expr EncodeArg(const Expr& arg, const Type& type) { |
247 | // we assume arg is either an identifier (var or globalvar) or a function |
248 | ICHECK(type.as<FuncTypeNode>()) << "assume no nested functions" ; |
249 | ICHECK(arg.as<VarNode>() || arg.as<GlobalVarNode>() || arg.as<FunctionNode>()) |
250 | << "assume all first-order-parameters are identifiers or functions" ; |
251 | |
252 | if (arg.as<VarNode>()) { |
253 | // variable with functype will be encoded as datatype in surrounding function |
254 | return arg; |
255 | } else if (arg.as<GlobalVarNode>()) { |
256 | return EncodeGlobalVar(Downcast<GlobalVar>(arg), Downcast<FuncType>(type)); |
257 | } else if (auto fn = arg.as<FunctionNode>()) { |
258 | // we handle free vars in anonymous functions by adding arguments to |
259 | // the constructor function |
260 | auto free_vars = FreeVars(arg); |
261 | auto ft = Downcast<FuncType>(type); |
262 | |
263 | auto arg_types = Array<Type>(); |
264 | auto pattern_vars = Array<Pattern>(); |
265 | auto call_args = Array<Expr>(); |
266 | Map<Var, Expr> free_var_bind_map; |
267 | for (auto free_var : free_vars) { |
268 | // free vars are already encoded, can only exist within |
269 | // specialized functions |
270 | if (free_var->type_annotation.defined()) { |
271 | arg_types.push_back(free_var->type_annotation); |
272 | } else { |
273 | arg_types.push_back(free_var->checked_type()); |
274 | } |
275 | auto new_var = Var(free_var->name_hint(), free_var->type_annotation); |
276 | free_var_bind_map.Set(free_var, new_var); |
277 | pattern_vars.push_back(PatternVar(new_var)); |
278 | call_args.push_back(free_var); |
279 | } |
280 | auto gtv = GetFuncEncode(ft); |
281 | auto c = Constructor(std::to_string(++constructor_counter), arg_types, gtv); |
282 | AddConstructor(gtv, c); |
283 | |
284 | auto apply_gv = GetApplyFunction(ft); |
285 | auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map)); |
286 | AddApplyCase(apply_gv, ft, c, WithFields(GetRef<Function>(fn), fn->params, body), |
287 | pattern_vars); |
288 | |
289 | return Call(c, call_args); |
290 | } |
291 | LOG(FATAL) << "EncodeArg failed to cast arg into identifier node or function node" ; |
292 | } |
293 | |
294 | /*! |
295 | * \brief encode a global var with a specialized type with a datatype |
296 | */ |
297 | Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) { |
298 | auto map = gv_datatype_map[gv]; |
299 | auto type_key = TypeToString(ft); |
300 | if (map.count(type_key) == 0) { |
301 | auto gtv = GetFuncEncode(ft); |
302 | auto c = Constructor(std::to_string(constructor_counter++), {}, gtv); |
303 | map[type_key] = c; |
304 | AddConstructor(gtv, c); |
305 | AddApplyCase(GetApplyFunction(ft), ft, c, gv, {}); |
306 | } |
307 | return Call(map[type_key], {}); |
308 | } |
309 | |
310 | /*! |
311 | * \brief type to string |
312 | */ |
313 | std::string TypeToString(const Type& t) { |
314 | std::ostringstream s; |
315 | s << t->GetTypeKey(); |
316 | return s.str(); |
317 | } |
318 | |
319 | /*! |
320 | * \brief get ADT handle for encoding type t |
321 | */ |
322 | GlobalTypeVar GetFuncEncode(const Type& t) { |
323 | auto adt_name = "Defunc" + TypeToString(t); |
324 | if (func_encoding.count(adt_name) == 0) { |
325 | func_encoding[adt_name] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle); |
326 | } |
327 | original_func_type_map[func_encoding[adt_name]] = t; |
328 | return func_encoding[adt_name]; |
329 | } |
330 | |
331 | /*! |
332 | * \brief get original function type represented by type t |
333 | */ |
334 | FuncType GetUnencodedType(const Type& t) { |
335 | auto tc = t.as<TypeCallNode>(); |
336 | ICHECK(tc) << "expected type call when getting original type from encoded type" ; |
337 | auto gv = tc->func.as<GlobalTypeVarNode>(); |
338 | ICHECK(gv) << "expected global type var in encoded type" ; |
339 | auto type = original_func_type_map[GetRef<GlobalTypeVar>(gv)]; |
340 | ICHECK(type.defined()) << "reverse mapping from encoded type to original type not found" ; |
341 | return Downcast<FuncType>(type); |
342 | } |
343 | |
344 | /*! |
345 | * \brief get the apply function for calling datatypes encoding functions of type t |
346 | */ |
347 | GlobalVar GetApplyFunction(const Type& t) { |
348 | auto f_name = "apply" + TypeToString(t); |
349 | if (apply_map.count(f_name) == 0) { |
350 | apply_map[f_name] = GlobalVar("apply" + TypeToString(t)); |
351 | } |
352 | return apply_map[f_name]; |
353 | } |
354 | |
355 | /*! |
356 | * \brief specialize a function type |
357 | */ |
358 | FuncType InstFuncType(const FuncTypeNode* fty, const Array<Type> type_args) { |
359 | ICHECK(fty) << "InstFuncType functype is null" ; |
360 | ICHECK_EQ(fty->type_params.size(), type_args.size()) |
361 | << "size mismatch between function type params and type args" ; |
362 | auto map = tvm::Map<TypeVar, Type>(); |
363 | for (size_t i = 0; i < type_args.size(); i++) { |
364 | map.Set(fty->type_params[i], type_args[i]); |
365 | } |
366 | // copy with typevars removed |
367 | return Downcast<FuncType>(TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map)); |
368 | } |
369 | |
370 | /*! |
371 | * \brief specialize a function expression |
372 | */ |
373 | Function Specialize(const Function& f, const Array<Type> type_args) { |
374 | ICHECK_EQ(f->type_params.size(), type_args.size()) |
375 | << "cannot specialize function with size mismatch between function type params and type " |
376 | "args" ; |
377 | auto map = tvm::Map<TypeVar, Type>(); |
378 | for (size_t i = 0; i < type_args.size(); i++) { |
379 | map.Set(f->type_params[i], type_args[i]); |
380 | } |
381 | // copy with typevars removed |
382 | auto copy = TypeSubst(WithFields(f, {}, {}, {}, /* erase type params */ Array<TypeVar>()), map); |
383 | return Downcast<Function>(copy); |
384 | } |
385 | |
386 | /*! |
387 | * \brief transform a function to be first order by transforming arg_types and |
388 | * using the `apply` function for applications |
389 | */ |
390 | Function FirstifyVars(const Function& f) { |
391 | ICHECK(f->type_params.size() == 0) << "firstify function has type params" ; |
392 | |
393 | tvm::Map<Var, Expr> var_bind_map; |
394 | Array<Var> params; |
395 | for (auto var : f->params) { |
396 | if (auto var_type = var->type_annotation.as<FuncTypeNode>()) { |
397 | // first order parameter |
398 | auto fop_type = GetRef<FuncType>(var_type); |
399 | auto adt = GetFuncEncode(fop_type); |
400 | auto new_var = Var(var->name_hint(), TypeCall(adt, {})); |
401 | mod->LookupTypeDef(adt); |
402 | var_bind_map.Set(var, new_var); |
403 | params.push_back(new_var); |
404 | } else { |
405 | ICHECK(!HasFuncType(var->type_annotation)) |
406 | << "nested function type in parameter not supported yet" ; |
407 | params.push_back(var); |
408 | } |
409 | } |
410 | |
411 | auto bind = Downcast<Function>(Bind(f, var_bind_map)); |
412 | return WithFields(bind, params, this->VisitExpr(bind->body), bind->ret_type, |
413 | /* erase type params */ Array<TypeVar>()); |
414 | } |
415 | }; |
416 | |
417 | Expr Defunctionalization(const Function& f, const IRModule& mod) { |
418 | // f is the starting point of the program, all types MUST be known |
419 | ICHECK(f->type_params.size() == 0) << "no polymorphism supported for defunctionalization" ; |
420 | for (const auto& p : f->params) { |
421 | ICHECK(!HasFuncType(p->checked_type())) << "program cannot have func type parameters" ; |
422 | } |
423 | ICHECK(!HasFuncType(f->ret_type)) << "return type cannot contain function" ; |
424 | |
425 | return Downcast<Function>(DefuncMutator(mod).VisitExpr(f)); |
426 | } |
427 | |
428 | TVM_REGISTER_GLOBAL("relay._transform.Defunctionalization" ).set_body_typed(Defunctionalization); |
429 | |
430 | } // namespace relay |
431 | } // namespace tvm |
432 | |