1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file to_cps.cc
23 *
24 * \brief Turn a program to continuation passing style.
25 *
 * Given a fresh type variable 'answer',
 * continuation passing style (CPS) converts every function of type a -> b to a -> (b -> answer) -> answer.
28 *
 * That is, instead of returning its result directly,
 * each function now passes the result to another function (called the continuation)
 * and returns whatever the continuation returns.
32 *
 * Continuation passing style turns every function call into a tail call,
 * which bounds the stack size, prevents the stack from overflowing during recursion,
 * and allows tail call optimization.
36 *
 * In Relay, since tensor operations are the bottleneck,
 * CPS is currently intended to transform the program before partial evaluation (PE),
 * as it reifies the control flow and enables PE to handle control-flow joins more aggressively.
40 *
41 * For example, given 'let a = if b then c else d in e', it will transform the code into
42 * 'let f a = e in if b then f c else f d'.
 * This allows f to be optimized individually in both branches.
44 *
45 * We implement CPS conversion by higher order transform
46 * (see http://matt.might.net/articles/cps-conversion/).
47 * The basic idea is that we will recursively traverse the AST.
48 * During the traversal, there is an extra parameter, mcont, of expr -> expr.
49 * It is basically a continuation at the metalevel.
 * All cases in the transform must return via the mcont,
 * whether by invoking it directly or indirectly through recursion.
52 */
53#include <tvm/ir/type_functor.h>
54#include <tvm/relay/expr_functor.h>
55#include <tvm/relay/feature.h>
56#include <tvm/relay/pattern_functor.h>
57#include <tvm/relay/transform.h>
58
59#include "let_list.h"
60#include "pass_utils.h"
61
62namespace tvm {
63namespace relay {
64
// We assume algebraic data types contain no closures: data type definitions are not inspected yet.
66
67Type Arrow(const Type& l, const Type& r) { return FuncType({l}, r, {}, {}); }
68
69Type CPSType(const Type& t, const TypeVar& answer);
70
71FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) {
72 tvm::Array<Type> new_arg_types;
73 for (const Type& t : f->arg_types) {
74 new_arg_types.push_back(CPSType(t, answer));
75 }
76 new_arg_types.push_back(Arrow(CPSType(f->ret_type, answer), answer));
77 return FuncType(new_arg_types, answer, f->type_params, f->type_constraints);
78}
79
80Type CPSType(const Type& t, const TypeVar& answer) {
81 struct CPSTypeMutator : TypeMutator {
82 explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) {}
83 TypeVar answer;
84 Type VisitType_(const FuncTypeNode* t) final {
85 return CPSFuncType(GetRef<FuncType>(t), answer);
86 }
87 } mut(answer);
88 return mut(t);
89}
90
91// transform global functions into cps form.
92using CPSMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
93
94// transform vars from the original program into new vars, so their type will be correct.
95using VarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;
96
97/*
98 * The meta continuation.
99 * There is 3 rules on the metacontinuation:
100 * 0: It can only use the argument once.
101 * The argument is code, and using it twice will duplicate code.
102 * Bound the argument via let instead.
103 * 1: If the size of the metacontinuation is unbounded, it can only be called once.
104 * It contain code, so calling it twice duplicate code.
105 * Reify the continuation and bound it instead.
106 * See the function 'reify' and the if case for more detail.
107 * 2: The argument must be effect free.
108 * It might reorder or drop the argument.
109 * Again, bound the argument via let instead.
110 * See the call case for more detail.
111 */
112using MCont = std::function<Expr(const Expr&)>;
113
114Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm);
115
116Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
117 const TypeVar& answer) {
118 std::function<Var(Var)> remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); };
119 auto function_type = Downcast<FuncType>(f->checked_type());
120 // Each MCont can be used at most once.
121 struct CPSFunctor : ExprFunctor<Expr(const Expr&, const MCont&)>, PatternMutator {
122 CPSFunctor(const std::function<Var(Var)>& remap, const TypeVar& answer, const IRModule& m,
123 VarMap* vm, CPSMap* cm)
124 : remap(remap), answer(answer), m(m), vm(vm), cm(cm) {}
125 const std::function<Var(Var)>& remap;
126 TypeVar answer;
127 IRModule m;
128 VarMap* vm;
129 CPSMap* cm;
130
131 Expr VisitExpr_(const LetNode* op, const MCont& k) final {
132 return VisitExpr(
133 op->value, [&](const Expr& v) { return Let(remap(op->var), v, VisitExpr(op->body, k)); });
134 }
135
136 Expr VisitExpr_(const FunctionNode* op, const MCont& k) final {
137 ICHECK(!op->HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet.";
138 return k(ToCPS(GetRef<Function>(op), m, cm, vm, answer));
139 }
140
141 Expr VisitExpr_(const ConstantNode* op, const MCont& k) final {
142 return k(GetRef<Constant>(op));
143 }
144
145 Expr VisitExpr_(const VarNode* op, const MCont& k) final { return k(remap(GetRef<Var>(op))); }
146
147 Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(remap(op->var)); }
148
149 Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
150 auto gv = GetRef<GlobalVar>(op);
151 if (cm->count(gv) == 0) {
152 // only look unfold non-external calls.
153 BaseFunc base_func = m->Lookup(gv);
154 if (auto* n = base_func.as<FunctionNode>()) {
155 auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps");
156 cm->insert({gv, cps_gv});
157 m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm));
158 } else {
159 // return the original global var if it is
160 // an external call to non-relay function.
161 return GetRef<GlobalVar>(op);
162 }
163 }
164 return k(cm->at(gv));
165 }
166
167 Expr VisitExpr_(const RefCreateNode* op, const MCont& k) final {
168 return VisitExpr(op->value, [&](const Expr& v) { return k(RefCreate(v)); });
169 }
170
171 Expr reify(const MCont& k) {
172 Var arg = Var("arg", Type());
173 return Function({arg}, k(arg), Type(), {}, {});
174 }
175
176 Expr reify(const MCont& k, const std::function<Expr(MCont)>& cont) {
177 return LetList::LetBind(reify(k), [&](const Var& f) {
178 return cont([&](const Expr& e) { return Call(f, {e}); });
179 });
180 }
181
182 Expr VisitExpr_(const IfNode* op, const MCont& k) final {
183 return reify(k, [&](const MCont& kf) {
184 return VisitExpr(op->cond, [&](const Expr& v) {
185 return If(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf));
186 });
187 });
188 }
189
190 Expr VisitExpr_(const MatchNode* op, const MCont& k) final {
191 return reify(k, [&](const MCont& kf) {
192 return VisitExpr(op->data, [&](const Expr& v) {
193 tvm::Array<Clause> clauses;
194 for (const auto& c : op->clauses) {
195 clauses.push_back(Clause(VisitPattern(c->lhs), VisitExpr(c->rhs, kf)));
196 }
197 return Match(v, clauses, op->complete);
198 });
199 });
200 }
201
202 Expr VisitExpr_(const RefReadNode* op, const MCont& k) final {
203 return VisitExpr(op->ref, [&](const Expr& r) { return LetList::LetBind(RefRead(r), k); });
204 }
205
206 Expr VisitExpr_(const RefWriteNode* op, const MCont& k) final {
207 return VisitExpr(op->ref, [&](const Expr& r) {
208 return VisitExpr(op->value,
209 [&](const Expr& v) { return LetList::LetBind(RefWrite(r, v), k); });
210 });
211 }
212
213 Expr VisitExpr_(const TupleNode* tuple_node, const MCont& k) final {
214 tvm::Array<Expr> fields;
215 fields.reserve(tuple_node->fields.size());
216 std::function<Expr()> next;
217 next = [&]() {
218 return (fields.size() == tuple_node->fields.size())
219 ? k(WithFields(GetRef<Tuple>(tuple_node), fields))
220 : VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) {
221 fields.push_back(v);
222 return next();
223 });
224 };
225 return next();
226 }
227
228 Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final {
229 return VisitExpr(op->tuple, [&](const Expr& v) { return k(TupleGetItem(v, op->index)); });
230 }
231
232 Expr VisitExpr_(const CallNode* op, const MCont& k) final {
233 if (op->op.as<OpNode>() || op->op.as<ConstructorNode>()) {
234 tvm::Array<Expr> args;
235 std::function<Expr()> next;
236 next = [&]() {
237 if (args.size() == op->args.size()) {
238 return LetList::LetBind(Call(op->op, args, op->attrs, op->type_args), k);
239 } else {
240 return VisitExpr(op->args[args.size()], [&](const Expr& v) {
241 args.push_back(v);
242 return next();
243 });
244 }
245 };
246 return next();
247 } else {
248 Expr f;
249 tvm::Array<Expr> args;
250 std::function<Expr()> next;
251 next = [&]() {
252 if (args.size() == op->args.size()) {
253 args.push_back(reify(k));
254 return Expr(Call(f, args, op->attrs, op->type_args));
255 } else {
256 return VisitExpr(op->args[args.size()], [&](const Expr& v) {
257 args.push_back(v);
258 return next();
259 });
260 }
261 };
262 return VisitExpr(op->op, [&](const Expr& v) {
263 f = v;
264 return next();
265 });
266 }
267 }
268 } mut(remap, answer, m, vm, cm);
269 Var k = Var("k", Arrow(CPSType(function_type->ret_type, answer), answer));
270 tvm::Array<Var> new_params;
271 for (const Var& v : f->params) {
272 new_params.push_back(remap(v));
273 }
274 new_params.push_back(k);
275 return WithFields(f, new_params,
276 mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), answer);
277}
278
279Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
280 TypeVar answer = TypeVar("answer", kType);
281 VarMap var;
282 struct Remapper : ExprVisitor, PatternVisitor {
283 Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) {}
284 TypeVar answer;
285 VarMap* vm;
286 void VisitExpr_(const VarNode* vn) final {
287 Var v = GetRef<Var>(vn);
288 if (vm->count(v) == 0) {
289 auto ret = Var(v->name_hint(), CPSType(v->checked_type(), answer));
290 vm->insert({v, ret});
291 }
292 }
293
294 void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }
295
296 void VisitPattern_(const PatternVarNode* op) final { VisitExpr(op->var); }
297 } remap(answer, &var);
298 remap.VisitExpr(f);
299 Function ret = ToCPS(f, m, cm, &var, answer);
300 auto new_type_params = ret->type_params;
301 new_type_params.push_back(answer);
302 return WithFields(ret, ret->params, ret->body, ret->ret_type, new_type_params);
303}
304
305Function ToCPS(const Function& f, const IRModule& m) {
306 CheckFeature(f, m, FeatureSet::All() - fGraph);
307 CPSMap cps;
308 return ToCPS(f, m, &cps);
309}
310
311Function UnCPS(const Function& f) {
312 CheckFeature(f, FeatureSet::All() - fGraph);
313 ICHECK_GT(f->params.size(), 0);
314 Array<Var> new_params;
315 for (const auto& p : f->params) {
316 new_params.push_back(Var(p->name_hint(), p->checked_type()));
317 }
318 auto cont_type = Downcast<FuncType>(new_params.back()->type_annotation);
319 new_params.pop_back();
320 ICHECK_EQ(cont_type->arg_types.size(), 1);
321 auto new_ret_type = Type(cont_type->arg_types[0]);
322 Array<TypeVar> new_type_params;
323 for (const auto& tp : f->type_params) {
324 new_type_params.push_back(TypeVar(tp->name_hint, tp->kind));
325 }
326 auto answer_type = new_type_params.back();
327 new_type_params.pop_back();
328 // TODO(@M.K.): make alphaequal work on free term
329 // ICHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type, answer_type)));
330 auto x = Var("x", new_ret_type);
331 auto cont = Function({x}, x, new_ret_type, {}, {});
332 tvm::Array<Expr> args;
333 for (const auto& p : new_params) {
334 args.push_back(p);
335 }
336 args.push_back(cont);
337 tvm::Array<Type> type_args;
338 for (const auto& tp : new_type_params) {
339 type_args.push_back(tp);
340 }
341 type_args.push_back(new_ret_type);
342 return WithFields(f, new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params);
343}
344
345TVM_REGISTER_GLOBAL("relay._transform.to_cps")
346 .set_body_typed(static_cast<Function (*)(const Function&, const IRModule&)>(ToCPS));
347
348TVM_REGISTER_GLOBAL("relay._transform.un_cps").set_body_typed(UnCPS);
349
350namespace transform {
351
352Pass ToCPS() {
353 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
354 [=](Function f, IRModule m, PassContext pc) { return Function(ToCPS(f, m)); };
355 return CreateFunctionPass(pass_func, 1, "ToCPS", {});
356}
357
358TVM_REGISTER_GLOBAL("relay._transform.ToCPS").set_body_typed(ToCPS);
359
360Pass UnCPS() {
361 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
362 [=](Function f, IRModule m, PassContext pc) { return Function(UnCPS(f)); };
363 return CreateFunctionPass(pass_func, 1, "UnCPS", {});
364}
365
366TVM_REGISTER_GLOBAL("relay._transform.UnCPS").set_body_typed(UnCPS);
367
368} // namespace transform
369
370} // namespace relay
371} // namespace tvm
372