1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file defunctionalization.cc
23 *
24 * \brief Defunctionalization for Relay IR
25 *
26 * This pass transforms a higher-order program into a first-order program with defunctionalization.
 * This means that all higher order functions (i.e. functions that take function arguments or
 * return functions) should be transformed into semantically equivalent first order ones.
29 *
30 * This pass implements a basic typed defunctionalization method.
31 * All higher order functions are cloned and specialized (so that there are no type params).
32 * Function type arguments are encoded as datatypes and a helper `apply` function is used
33 * to "call" them.
34 *
35 * For example, take the following higher order program:
 *       fun map F y = case y of
 *          Nil => Nil
 *          | Cons(x, XS) => Cons(F x, map F XS)
 *       fun addone l = map (\x -> x + 1) l
40 *
41 * where `addone` is our program.
42 * When we call the `map` function, we see that it is a higher-order function,
 * but we can clone the `map` function and specialize it with the type_params of the call.
44 * In addition, our function argument `(\x -> \x + 1)` will be encoded as a datatype constructor,
45 * which we will call `incr`, and all calls to `F` in our specialized map function will use the
46 * helper `apply` function.
47 *
48 * After defunctionalization, we get:
 *       fun apply encoding arg = case encoding of
 *          "incr" => incr arg
 *       fun map' F y = case y of
 *          Nil => Nil
 *          | Cons(x, xs) => Cons(apply F x, map' F xs)
 *       fun addone l = map' "incr" l
55 *
56 * Currently, defunctionalization makes the following assumptions:
57 * - functions cannot return function values
58 * - function arguments are in two forms: identifier or a lambda abstraction
59 * - no functions stored in datatype
 * - functions are not let-bound
61 */
62
63#include <tvm/ir/type_functor.h>
64#include <tvm/relay/analysis.h>
65#include <tvm/relay/expr_functor.h>
66#include <tvm/relay/feature.h>
67#include <tvm/relay/transform.h>
68#include <tvm/te/operation.h>
69
70#include "../analysis/type_solver.h"
71#include "../transforms/pass_utils.h"
72namespace tvm {
73namespace relay {
74
75// determine if type contains a FuncType
76bool HasFuncType(const Type& t) {
77 struct FuncTypeVisitor : TypeVisitor {
78 bool has_func_type;
79 FuncTypeVisitor() : has_func_type(false) {}
80
81 void VisitType_(const FuncTypeNode* op) { this->has_func_type = true; }
82 };
83
84 auto visitor = FuncTypeVisitor();
85 visitor.VisitType(t);
86 return visitor.has_func_type;
87}
88// determine if FuncType is a higher order type
89bool IsHigherOrderFunc(const FuncType& t) {
90 bool higher_order = false;
91 for (auto arg : t->arg_types) {
92 higher_order |= HasFuncType(arg);
93 }
94 return higher_order |= HasFuncType(t->ret_type);
95}
96
97/*!
98 * \brief mutator for driving the Defunctionalization transformation
99 */
100class DefuncMutator : public ExprMutator {
101 public:
102 explicit DefuncMutator(const IRModule& mod) : mod(mod), constructor_counter(0) {}
103
104 Expr VisitExpr_(const CallNode* call) {
105 if (auto op = call->op.as<GlobalVarNode>()) {
106 ICHECK_EQ(call->type_args.size(), op->checked_type().as<FuncTypeNode>()->type_params.size())
107 << "all type args must be explicit";
108
109 auto op_type = InstFuncType(op->checked_type().as<FuncTypeNode>(), call->type_args);
110 ICHECK_EQ(FreeTypeVars(op_type, mod).size(), 0) << "free type vars in instantiated";
111 ICHECK(!HasFuncType(op_type->ret_type)) << "returning functions not supported";
112
113 if (!IsHigherOrderFunc(op_type)) {
114 // not higher order function
115 return ExprMutator::VisitExpr_(call);
116 }
117
118 // first we encode function arguments
119 Array<Expr> args;
120 for (size_t i = 0; i < call->args.size(); i++) {
121 auto arg = call->args[i];
122 auto type = op_type->arg_types[i];
123 if (!HasFuncType(type)) {
124 args.push_back(arg);
125 } else {
126 args.push_back(EncodeArg(arg, type));
127 }
128 }
129 auto name = op->name_hint + TypeToString(op_type);
130 auto gv = GlobalVar(name);
131 if (specialized_gv_map.count(name)) {
132 gv = specialized_gv_map[name];
133 } else {
134 specialized_gv_map[name] = gv;
135 // clone and specialize with specific type
136 auto clone = Downcast<Function>(DeDup(mod->Lookup(GetRef<GlobalVar>(op))));
137 auto specialized_function = Specialize(clone, call->type_args);
138 // change var types and change all applications to use `apply` method
139 auto f = Downcast<Function>(FirstifyVars(specialized_function));
140 mod->Add(gv, f);
141 }
142 return Call(gv, args);
143 } else if (auto op = call->op.as<FunctionNode>()) {
144 // reduction by applying vars
145 std::unordered_map<Var, Expr, ObjectHash, ObjectEqual> var_binding_map;
146 for (size_t i = 0; i < op->params.size(); i++) {
147 var_binding_map[op->params[i]] = call->args[i];
148 }
149 auto e = Bind(op->body, var_binding_map);
150 return this->VisitExpr(e);
151 } else if (auto op = call->op.as<VarNode>()) {
152 // var node will be encoded as datatype
153 // so we need to use the `apply` helper method
154 auto var_original_type = GetUnencodedType(op->type_annotation).as<FuncTypeNode>();
155 ICHECK(var_original_type) << "var original type not saved in var_save_type map";
156 auto op_type = InstFuncType(var_original_type, call->type_args);
157
158 Array<Expr> args = {GetRef<Var>(op)};
159 for (auto arg : call->args) {
160 args.push_back(this->VisitExpr(arg));
161 }
162
163 return Call(GetApplyFunction(op_type), args);
164 }
165 return ExprMutator::VisitExpr_(call);
166 }
167
168 private:
169 // module
170 IRModule mod;
171 // gv + str(type) to specialized clone gv
172 std::unordered_map<std::string, GlobalVar> specialized_gv_map;
173 // str(func_type) to ADT
174 std::unordered_map<std::string, GlobalTypeVar> func_encoding;
175 // str(func_tyoe) to apply gv
176 std::unordered_map<std::string, GlobalVar> apply_map;
177 // encoded ADT handle to FuncType
178 std::unordered_map<GlobalTypeVar, Type, ObjectHash, StructuralEqual> original_func_type_map;
179 // gv to (str(func_type) to constructor encoding)
180 std::unordered_map<GlobalVar, std::unordered_map<std::string, Constructor>, ObjectHash,
181 ObjectEqual>
182 gv_datatype_map;
183 // use monotonically increasing integer to represent new constructor_name
184 uint64_t constructor_counter;
185
186 /*!
187 * \brief add a constructor to the GlobalTypeVar, creating a new TypeDef if GlobalTypeVar does not
188 * exist
189 */
190 void AddConstructor(GlobalTypeVar gtv, Constructor c) {
191 if (!mod->ContainGlobalTypeVar(gtv->name_hint)) {
192 mod->AddTypeDef(gtv, TypeData(gtv, {}, {c}));
193 } else {
194 auto typedata = mod->LookupTypeDef(gtv);
195 auto constructors = typedata->constructors;
196 constructors.push_back(c);
197 mod->UpdateTypeDef(gtv, TypeData(typedata->header, typedata->type_vars, constructors));
198 }
199 }
200 /*!
201 * \brief add a case to the apply function, creating the function if it does not exist
202 *
203 * \param apply_gv GlobalVar of the apply function
204 * \param ft is the type functions the apply function handles
205 * \param c constructor to add a case for
206 * \param expr calls this expr with the args to the apply_gv
207 * \param patterns PatterVars to match with the constructor, used for handling free vars in
208 * functions
209 */
210 void AddApplyCase(GlobalVar apply_gv, FuncType ft, Constructor c, const Expr& expr,
211 const Array<Pattern> patterns) {
212 ICHECK(c->inputs.size() == patterns.size())
213 << "constructor function and pattern vars have different sizes";
214 if (!mod->ContainGlobalVar(apply_gv->name_hint)) {
215 auto x = Var("x", TypeCall(c->belong_to, {}));
216 auto vars = Array<Var>({x});
217 auto args = Array<Expr>();
218 for (auto t : ft->arg_types) {
219 auto y = Var("y", t);
220 vars.push_back(y);
221 args.push_back(y);
222 }
223
224 auto clauses = Array<Clause>({Clause(PatternConstructor(c, patterns), Call(expr, args))});
225 auto body = Match(x, clauses);
226 auto f = Function(vars, body, ft->ret_type, {});
227
228 mod->Add(apply_gv, f);
229 } else {
230 auto f = Downcast<Function>(mod->Lookup(apply_gv));
231 auto body = f->body.as<MatchNode>();
232 ICHECK(body) << "internal invariant broken; apply function body should be a match node";
233
234 auto clauses = body->clauses;
235 auto x = f->params[0];
236 auto args = Array<Expr>();
237 for (size_t i = 1; i < f->params.size(); i++) {
238 args.push_back(f->params[i]);
239 }
240 clauses.push_back(Clause(PatternConstructor(c, patterns), Call(expr, args)));
241
242 mod->Add(apply_gv, Function(f->params, Match(x, clauses), f->ret_type, f->type_params), true);
243 }
244 }
245
246 Expr EncodeArg(const Expr& arg, const Type& type) {
247 // we assume arg is either an identifier (var or globalvar) or a function
248 ICHECK(type.as<FuncTypeNode>()) << "assume no nested functions";
249 ICHECK(arg.as<VarNode>() || arg.as<GlobalVarNode>() || arg.as<FunctionNode>())
250 << "assume all first-order-parameters are identifiers or functions";
251
252 if (arg.as<VarNode>()) {
253 // variable with functype will be encoded as datatype in surrounding function
254 return arg;
255 } else if (arg.as<GlobalVarNode>()) {
256 return EncodeGlobalVar(Downcast<GlobalVar>(arg), Downcast<FuncType>(type));
257 } else if (auto fn = arg.as<FunctionNode>()) {
258 // we handle free vars in anonymous functions by adding arguments to
259 // the constructor function
260 auto free_vars = FreeVars(arg);
261 auto ft = Downcast<FuncType>(type);
262
263 auto arg_types = Array<Type>();
264 auto pattern_vars = Array<Pattern>();
265 auto call_args = Array<Expr>();
266 Map<Var, Expr> free_var_bind_map;
267 for (auto free_var : free_vars) {
268 // free vars are already encoded, can only exist within
269 // specialized functions
270 if (free_var->type_annotation.defined()) {
271 arg_types.push_back(free_var->type_annotation);
272 } else {
273 arg_types.push_back(free_var->checked_type());
274 }
275 auto new_var = Var(free_var->name_hint(), free_var->type_annotation);
276 free_var_bind_map.Set(free_var, new_var);
277 pattern_vars.push_back(PatternVar(new_var));
278 call_args.push_back(free_var);
279 }
280 auto gtv = GetFuncEncode(ft);
281 auto c = Constructor(std::to_string(++constructor_counter), arg_types, gtv);
282 AddConstructor(gtv, c);
283
284 auto apply_gv = GetApplyFunction(ft);
285 auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map));
286 AddApplyCase(apply_gv, ft, c, WithFields(GetRef<Function>(fn), fn->params, body),
287 pattern_vars);
288
289 return Call(c, call_args);
290 }
291 LOG(FATAL) << "EncodeArg failed to cast arg into identifier node or function node";
292 }
293
294 /*!
295 * \brief encode a global var with a specialized type with a datatype
296 */
297 Expr EncodeGlobalVar(const GlobalVar& gv, const FuncType& ft) {
298 auto map = gv_datatype_map[gv];
299 auto type_key = TypeToString(ft);
300 if (map.count(type_key) == 0) {
301 auto gtv = GetFuncEncode(ft);
302 auto c = Constructor(std::to_string(constructor_counter++), {}, gtv);
303 map[type_key] = c;
304 AddConstructor(gtv, c);
305 AddApplyCase(GetApplyFunction(ft), ft, c, gv, {});
306 }
307 return Call(map[type_key], {});
308 }
309
310 /*!
311 * \brief type to string
312 */
313 std::string TypeToString(const Type& t) {
314 std::ostringstream s;
315 s << t->GetTypeKey();
316 return s.str();
317 }
318
319 /*!
320 * \brief get ADT handle for encoding type t
321 */
322 GlobalTypeVar GetFuncEncode(const Type& t) {
323 auto adt_name = "Defunc" + TypeToString(t);
324 if (func_encoding.count(adt_name) == 0) {
325 func_encoding[adt_name] = GlobalTypeVar(adt_name, TypeKind::kAdtHandle);
326 }
327 original_func_type_map[func_encoding[adt_name]] = t;
328 return func_encoding[adt_name];
329 }
330
331 /*!
332 * \brief get original function type represented by type t
333 */
334 FuncType GetUnencodedType(const Type& t) {
335 auto tc = t.as<TypeCallNode>();
336 ICHECK(tc) << "expected type call when getting original type from encoded type";
337 auto gv = tc->func.as<GlobalTypeVarNode>();
338 ICHECK(gv) << "expected global type var in encoded type";
339 auto type = original_func_type_map[GetRef<GlobalTypeVar>(gv)];
340 ICHECK(type.defined()) << "reverse mapping from encoded type to original type not found";
341 return Downcast<FuncType>(type);
342 }
343
344 /*!
345 * \brief get the apply function for calling datatypes encoding functions of type t
346 */
347 GlobalVar GetApplyFunction(const Type& t) {
348 auto f_name = "apply" + TypeToString(t);
349 if (apply_map.count(f_name) == 0) {
350 apply_map[f_name] = GlobalVar("apply" + TypeToString(t));
351 }
352 return apply_map[f_name];
353 }
354
355 /*!
356 * \brief specialize a function type
357 */
358 FuncType InstFuncType(const FuncTypeNode* fty, const Array<Type> type_args) {
359 ICHECK(fty) << "InstFuncType functype is null";
360 ICHECK_EQ(fty->type_params.size(), type_args.size())
361 << "size mismatch between function type params and type args";
362 auto map = tvm::Map<TypeVar, Type>();
363 for (size_t i = 0; i < type_args.size(); i++) {
364 map.Set(fty->type_params[i], type_args[i]);
365 }
366 // copy with typevars removed
367 return Downcast<FuncType>(TypeSubst(FuncType(fty->arg_types, fty->ret_type, {}, {}), map));
368 }
369
370 /*!
371 * \brief specialize a function expression
372 */
373 Function Specialize(const Function& f, const Array<Type> type_args) {
374 ICHECK_EQ(f->type_params.size(), type_args.size())
375 << "cannot specialize function with size mismatch between function type params and type "
376 "args";
377 auto map = tvm::Map<TypeVar, Type>();
378 for (size_t i = 0; i < type_args.size(); i++) {
379 map.Set(f->type_params[i], type_args[i]);
380 }
381 // copy with typevars removed
382 auto copy = TypeSubst(WithFields(f, {}, {}, {}, /* erase type params */ Array<TypeVar>()), map);
383 return Downcast<Function>(copy);
384 }
385
386 /*!
387 * \brief transform a function to be first order by transforming arg_types and
388 * using the `apply` function for applications
389 */
390 Function FirstifyVars(const Function& f) {
391 ICHECK(f->type_params.size() == 0) << "firstify function has type params";
392
393 tvm::Map<Var, Expr> var_bind_map;
394 Array<Var> params;
395 for (auto var : f->params) {
396 if (auto var_type = var->type_annotation.as<FuncTypeNode>()) {
397 // first order parameter
398 auto fop_type = GetRef<FuncType>(var_type);
399 auto adt = GetFuncEncode(fop_type);
400 auto new_var = Var(var->name_hint(), TypeCall(adt, {}));
401 mod->LookupTypeDef(adt);
402 var_bind_map.Set(var, new_var);
403 params.push_back(new_var);
404 } else {
405 ICHECK(!HasFuncType(var->type_annotation))
406 << "nested function type in parameter not supported yet";
407 params.push_back(var);
408 }
409 }
410
411 auto bind = Downcast<Function>(Bind(f, var_bind_map));
412 return WithFields(bind, params, this->VisitExpr(bind->body), bind->ret_type,
413 /* erase type params */ Array<TypeVar>());
414 }
415};
416
417Expr Defunctionalization(const Function& f, const IRModule& mod) {
418 // f is the starting point of the program, all types MUST be known
419 ICHECK(f->type_params.size() == 0) << "no polymorphism supported for defunctionalization";
420 for (const auto& p : f->params) {
421 ICHECK(!HasFuncType(p->checked_type())) << "program cannot have func type parameters";
422 }
423 ICHECK(!HasFuncType(f->ret_type)) << "return type cannot contain function";
424
425 return Downcast<Function>(DefuncMutator(mod).VisitExpr(f));
426}
427
// Register Defunctionalization as a typed global function under the name
// "relay._transform.Defunctionalization".
TVM_REGISTER_GLOBAL("relay._transform.Defunctionalization").set_body_typed(Defunctionalization);
429
430} // namespace relay
431} // namespace tvm
432