1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file to_cps.cc |
23 | * |
24 | * \brief Turn a program to continuation passing style. |
25 | * |
 * Given a fresh type variable 'answer',
 * continuation passing style (CPS) converts every function of type a -> b into a -> (b -> answer) -> answer.
28 | * |
 * That is, instead of returning its result directly,
 * a function now calls another function (called the continuation) with that result
 * and returns whatever the continuation returns.
32 | * |
 * Continuation passing style turns every function call into a tail call,
 * which bounds the stack size, prevents the stack from overflowing during recursion,
 * and allows tail call optimization.
36 | * |
 * In Relay, since tensor operations are the bottleneck,
 * CPS is currently intended to transform the program before partial evaluation (PE),
 * as it reifies the control flow and enables PE to handle control-flow joins more aggressively.
40 | * |
41 | * For example, given 'let a = if b then c else d in e', it will transform the code into |
42 | * 'let f a = e in if b then f c else f d'. |
 * This allows f to be optimized individually in both branches.
44 | * |
45 | * We implement CPS conversion by higher order transform |
46 | * (see http://matt.might.net/articles/cps-conversion/). |
47 | * The basic idea is that we will recursively traverse the AST. |
48 | * During the traversal, there is an extra parameter, mcont, of expr -> expr. |
49 | * It is basically a continuation at the metalevel. |
50 | * All cases in the transform must return via the mcont, |
 * whether by invoking it directly, or indirectly through recursion.
52 | */ |
53 | #include <tvm/ir/type_functor.h> |
54 | #include <tvm/relay/expr_functor.h> |
55 | #include <tvm/relay/feature.h> |
56 | #include <tvm/relay/pattern_functor.h> |
57 | #include <tvm/relay/transform.h> |
58 | |
59 | #include "let_list.h" |
60 | #include "pass_utils.h" |
61 | |
62 | namespace tvm { |
63 | namespace relay { |
64 | |
// We assume data types contain no closures: this pass does not yet look inside data types.
66 | |
67 | Type Arrow(const Type& l, const Type& r) { return FuncType({l}, r, {}, {}); } |
68 | |
69 | Type CPSType(const Type& t, const TypeVar& answer); |
70 | |
71 | FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) { |
72 | tvm::Array<Type> new_arg_types; |
73 | for (const Type& t : f->arg_types) { |
74 | new_arg_types.push_back(CPSType(t, answer)); |
75 | } |
76 | new_arg_types.push_back(Arrow(CPSType(f->ret_type, answer), answer)); |
77 | return FuncType(new_arg_types, answer, f->type_params, f->type_constraints); |
78 | } |
79 | |
80 | Type CPSType(const Type& t, const TypeVar& answer) { |
81 | struct CPSTypeMutator : TypeMutator { |
82 | explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) {} |
83 | TypeVar answer; |
84 | Type VisitType_(const FuncTypeNode* t) final { |
85 | return CPSFuncType(GetRef<FuncType>(t), answer); |
86 | } |
87 | } mut(answer); |
88 | return mut(t); |
89 | } |
90 | |
// Maps each original global var to the global var of its CPS-converted counterpart
// (e.g. `@f` -> `@f_cps`), so that each global function is converted only once.
using CPSMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;

// Maps vars from the original program to fresh vars whose types are CPS-converted,
// so that their type annotations are correct after the transform.
using VarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;
96 | |
/*
 * The meta continuation (MCont): a continuation at the meta level. Given the
 * (already translated) value of a sub-expression, it produces the code for the
 * rest of the computation.
 * There are 3 rules on the metacontinuation:
 * 0: It can only use its argument once.
 *    The argument is code, and using it twice would duplicate code.
 *    Bind the argument via let instead.
 * 1: If the size of the metacontinuation is unbounded, it can only be called once.
 *    It contains code, so calling it twice would duplicate code.
 *    Reify the continuation and bind it instead.
 *    See the function 'reify' and the if case for more detail.
 * 2: The argument must be effect free.
 *    The metacontinuation might reorder or drop its argument.
 *    Again, bind the argument via let instead.
 *    See the call case for more detail.
 */
using MCont = std::function<Expr(const Expr&)>;
113 | |
114 | Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm); |
115 | |
116 | Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, |
117 | const TypeVar& answer) { |
118 | std::function<Var(Var)> remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); }; |
119 | auto function_type = Downcast<FuncType>(f->checked_type()); |
120 | // Each MCont can be used at most once. |
121 | struct CPSFunctor : ExprFunctor<Expr(const Expr&, const MCont&)>, PatternMutator { |
122 | CPSFunctor(const std::function<Var(Var)>& remap, const TypeVar& answer, const IRModule& m, |
123 | VarMap* vm, CPSMap* cm) |
124 | : remap(remap), answer(answer), m(m), vm(vm), cm(cm) {} |
125 | const std::function<Var(Var)>& remap; |
126 | TypeVar answer; |
127 | IRModule m; |
128 | VarMap* vm; |
129 | CPSMap* cm; |
130 | |
131 | Expr VisitExpr_(const LetNode* op, const MCont& k) final { |
132 | return VisitExpr( |
133 | op->value, [&](const Expr& v) { return Let(remap(op->var), v, VisitExpr(op->body, k)); }); |
134 | } |
135 | |
136 | Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { |
137 | ICHECK(!op->HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet." ; |
138 | return k(ToCPS(GetRef<Function>(op), m, cm, vm, answer)); |
139 | } |
140 | |
141 | Expr VisitExpr_(const ConstantNode* op, const MCont& k) final { |
142 | return k(GetRef<Constant>(op)); |
143 | } |
144 | |
145 | Expr VisitExpr_(const VarNode* op, const MCont& k) final { return k(remap(GetRef<Var>(op))); } |
146 | |
147 | Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(remap(op->var)); } |
148 | |
149 | Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { |
150 | auto gv = GetRef<GlobalVar>(op); |
151 | if (cm->count(gv) == 0) { |
152 | // only look unfold non-external calls. |
153 | BaseFunc base_func = m->Lookup(gv); |
154 | if (auto* n = base_func.as<FunctionNode>()) { |
155 | auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps" ); |
156 | cm->insert({gv, cps_gv}); |
157 | m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm)); |
158 | } else { |
159 | // return the original global var if it is |
160 | // an external call to non-relay function. |
161 | return GetRef<GlobalVar>(op); |
162 | } |
163 | } |
164 | return k(cm->at(gv)); |
165 | } |
166 | |
167 | Expr VisitExpr_(const RefCreateNode* op, const MCont& k) final { |
168 | return VisitExpr(op->value, [&](const Expr& v) { return k(RefCreate(v)); }); |
169 | } |
170 | |
171 | Expr reify(const MCont& k) { |
172 | Var arg = Var("arg" , Type()); |
173 | return Function({arg}, k(arg), Type(), {}, {}); |
174 | } |
175 | |
176 | Expr reify(const MCont& k, const std::function<Expr(MCont)>& cont) { |
177 | return LetList::LetBind(reify(k), [&](const Var& f) { |
178 | return cont([&](const Expr& e) { return Call(f, {e}); }); |
179 | }); |
180 | } |
181 | |
182 | Expr VisitExpr_(const IfNode* op, const MCont& k) final { |
183 | return reify(k, [&](const MCont& kf) { |
184 | return VisitExpr(op->cond, [&](const Expr& v) { |
185 | return If(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf)); |
186 | }); |
187 | }); |
188 | } |
189 | |
190 | Expr VisitExpr_(const MatchNode* op, const MCont& k) final { |
191 | return reify(k, [&](const MCont& kf) { |
192 | return VisitExpr(op->data, [&](const Expr& v) { |
193 | tvm::Array<Clause> clauses; |
194 | for (const auto& c : op->clauses) { |
195 | clauses.push_back(Clause(VisitPattern(c->lhs), VisitExpr(c->rhs, kf))); |
196 | } |
197 | return Match(v, clauses, op->complete); |
198 | }); |
199 | }); |
200 | } |
201 | |
202 | Expr VisitExpr_(const RefReadNode* op, const MCont& k) final { |
203 | return VisitExpr(op->ref, [&](const Expr& r) { return LetList::LetBind(RefRead(r), k); }); |
204 | } |
205 | |
206 | Expr VisitExpr_(const RefWriteNode* op, const MCont& k) final { |
207 | return VisitExpr(op->ref, [&](const Expr& r) { |
208 | return VisitExpr(op->value, |
209 | [&](const Expr& v) { return LetList::LetBind(RefWrite(r, v), k); }); |
210 | }); |
211 | } |
212 | |
213 | Expr VisitExpr_(const TupleNode* tuple_node, const MCont& k) final { |
214 | tvm::Array<Expr> fields; |
215 | fields.reserve(tuple_node->fields.size()); |
216 | std::function<Expr()> next; |
217 | next = [&]() { |
218 | return (fields.size() == tuple_node->fields.size()) |
219 | ? k(WithFields(GetRef<Tuple>(tuple_node), fields)) |
220 | : VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) { |
221 | fields.push_back(v); |
222 | return next(); |
223 | }); |
224 | }; |
225 | return next(); |
226 | } |
227 | |
228 | Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final { |
229 | return VisitExpr(op->tuple, [&](const Expr& v) { return k(TupleGetItem(v, op->index)); }); |
230 | } |
231 | |
232 | Expr VisitExpr_(const CallNode* op, const MCont& k) final { |
233 | if (op->op.as<OpNode>() || op->op.as<ConstructorNode>()) { |
234 | tvm::Array<Expr> args; |
235 | std::function<Expr()> next; |
236 | next = [&]() { |
237 | if (args.size() == op->args.size()) { |
238 | return LetList::LetBind(Call(op->op, args, op->attrs, op->type_args), k); |
239 | } else { |
240 | return VisitExpr(op->args[args.size()], [&](const Expr& v) { |
241 | args.push_back(v); |
242 | return next(); |
243 | }); |
244 | } |
245 | }; |
246 | return next(); |
247 | } else { |
248 | Expr f; |
249 | tvm::Array<Expr> args; |
250 | std::function<Expr()> next; |
251 | next = [&]() { |
252 | if (args.size() == op->args.size()) { |
253 | args.push_back(reify(k)); |
254 | return Expr(Call(f, args, op->attrs, op->type_args)); |
255 | } else { |
256 | return VisitExpr(op->args[args.size()], [&](const Expr& v) { |
257 | args.push_back(v); |
258 | return next(); |
259 | }); |
260 | } |
261 | }; |
262 | return VisitExpr(op->op, [&](const Expr& v) { |
263 | f = v; |
264 | return next(); |
265 | }); |
266 | } |
267 | } |
268 | } mut(remap, answer, m, vm, cm); |
269 | Var k = Var("k" , Arrow(CPSType(function_type->ret_type, answer), answer)); |
270 | tvm::Array<Var> new_params; |
271 | for (const Var& v : f->params) { |
272 | new_params.push_back(remap(v)); |
273 | } |
274 | new_params.push_back(k); |
275 | return WithFields(f, new_params, |
276 | mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), answer); |
277 | } |
278 | |
279 | Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { |
280 | TypeVar answer = TypeVar("answer" , kType); |
281 | VarMap var; |
282 | struct Remapper : ExprVisitor, PatternVisitor { |
283 | Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) {} |
284 | TypeVar answer; |
285 | VarMap* vm; |
286 | void VisitExpr_(const VarNode* vn) final { |
287 | Var v = GetRef<Var>(vn); |
288 | if (vm->count(v) == 0) { |
289 | auto ret = Var(v->name_hint(), CPSType(v->checked_type(), answer)); |
290 | vm->insert({v, ret}); |
291 | } |
292 | } |
293 | |
294 | void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } |
295 | |
296 | void VisitPattern_(const PatternVarNode* op) final { VisitExpr(op->var); } |
297 | } remap(answer, &var); |
298 | remap.VisitExpr(f); |
299 | Function ret = ToCPS(f, m, cm, &var, answer); |
300 | auto new_type_params = ret->type_params; |
301 | new_type_params.push_back(answer); |
302 | return WithFields(ret, ret->params, ret->body, ret->ret_type, new_type_params); |
303 | } |
304 | |
305 | Function ToCPS(const Function& f, const IRModule& m) { |
306 | CheckFeature(f, m, FeatureSet::All() - fGraph); |
307 | CPSMap cps; |
308 | return ToCPS(f, m, &cps); |
309 | } |
310 | |
311 | Function UnCPS(const Function& f) { |
312 | CheckFeature(f, FeatureSet::All() - fGraph); |
313 | ICHECK_GT(f->params.size(), 0); |
314 | Array<Var> new_params; |
315 | for (const auto& p : f->params) { |
316 | new_params.push_back(Var(p->name_hint(), p->checked_type())); |
317 | } |
318 | auto cont_type = Downcast<FuncType>(new_params.back()->type_annotation); |
319 | new_params.pop_back(); |
320 | ICHECK_EQ(cont_type->arg_types.size(), 1); |
321 | auto new_ret_type = Type(cont_type->arg_types[0]); |
322 | Array<TypeVar> new_type_params; |
323 | for (const auto& tp : f->type_params) { |
324 | new_type_params.push_back(TypeVar(tp->name_hint, tp->kind)); |
325 | } |
326 | auto answer_type = new_type_params.back(); |
327 | new_type_params.pop_back(); |
328 | // TODO(@M.K.): make alphaequal work on free term |
329 | // ICHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type, answer_type))); |
330 | auto x = Var("x" , new_ret_type); |
331 | auto cont = Function({x}, x, new_ret_type, {}, {}); |
332 | tvm::Array<Expr> args; |
333 | for (const auto& p : new_params) { |
334 | args.push_back(p); |
335 | } |
336 | args.push_back(cont); |
337 | tvm::Array<Type> type_args; |
338 | for (const auto& tp : new_type_params) { |
339 | type_args.push_back(tp); |
340 | } |
341 | type_args.push_back(new_ret_type); |
342 | return WithFields(f, new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params); |
343 | } |
344 | |
345 | TVM_REGISTER_GLOBAL("relay._transform.to_cps" ) |
346 | .set_body_typed(static_cast<Function (*)(const Function&, const IRModule&)>(ToCPS)); |
347 | |
348 | TVM_REGISTER_GLOBAL("relay._transform.un_cps" ).set_body_typed(UnCPS); |
349 | |
350 | namespace transform { |
351 | |
352 | Pass ToCPS() { |
353 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
354 | [=](Function f, IRModule m, PassContext pc) { return Function(ToCPS(f, m)); }; |
355 | return CreateFunctionPass(pass_func, 1, "ToCPS" , {}); |
356 | } |
357 | |
358 | TVM_REGISTER_GLOBAL("relay._transform.ToCPS" ).set_body_typed(ToCPS); |
359 | |
360 | Pass UnCPS() { |
361 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
362 | [=](Function f, IRModule m, PassContext pc) { return Function(UnCPS(f)); }; |
363 | return CreateFunctionPass(pass_func, 1, "UnCPS" , {}); |
364 | } |
365 | |
366 | TVM_REGISTER_GLOBAL("relay._transform.UnCPS" ).set_body_typed(UnCPS); |
367 | |
368 | } // namespace transform |
369 | |
370 | } // namespace relay |
371 | } // namespace tvm |
372 | |