1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file to_a_normal_form.cc
23 *
24 * \brief Turn implicit sharing into observable sharing.
25 */
26#include <tvm/relay/analysis.h>
27#include <tvm/relay/expr_functor.h>
28#include <tvm/relay/transform.h>
29#include <tvm/runtime/logging.h>
30
31#include "../../support/arena.h"
32#include "../analysis/dependency_graph.h"
33#include "../op/annotation/annotation.h"
34#include "./device_aware_visitors.h"
35#include "./let_list.h"
36#include "./pass_utils.h"
37
38namespace tvm {
39namespace relay {
40
41Scope LCA(Scope lhs, Scope rhs) {
42 while (lhs != rhs) {
43 if (lhs->level > rhs->level) {
44 lhs = lhs->parent;
45 } else if (lhs->level < rhs->level) {
46 rhs = rhs->parent;
47 } else {
48 lhs = lhs->parent;
49 rhs = rhs->parent;
50 }
51 }
52 return lhs;
53}
54
/*!
 * \brief Assigns a scope to every node of the dependency graph.
 *
 * Nodes are visited in reverse post-DFS order, so every parent is processed
 * before its children; a node's scope is then the LCA of the scopes already
 * assigned to all of its parents. Nodes flagged `new_scope` (sub-scope
 * boundaries, e.g. function bodies and if branches) receive a fresh child
 * scope under that LCA instead.
 *
 * \return A pair of:
 *  - the node-to-scope map, and
 *  - the set of expressions whose scope was "lifted" above the scope of
 *    their first parent (consulted by basic-block normal form to decide
 *    which expressions must be let-bound).
 */
std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg) {
  NodeScopeMap expr_scope;
  ExprSet lifted_exprs;
  std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;
  // Invert dg.expr_node so we can recover the Expr behind a graph node below.
  for (auto expr_node : dg.expr_node) {
    node_to_expr[expr_node.second] = expr_node.first;
  }
  bool global_scope_used = false;
  Scope global_scope = std::make_shared<ScopeNode>();

  for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
    DependencyGraph::Node* n = *it;
    auto iit = n->parents.head;
    Scope s;
    if (iit == nullptr) {
      // A parentless node is the root of the expression; exactly one node may
      // claim the global scope.
      ICHECK(!global_scope_used);
      s = global_scope;
      global_scope_used = true;
    } else {
      // Fold the LCA over the scopes of all parents (already computed thanks
      // to the reverse post-DFS traversal order).
      s = expr_scope.at(iit->value);
      const auto original_s = s;
      iit = iit->next;
      for (; iit != nullptr; iit = iit->next) {
        s = LCA(s, expr_scope.at(iit->value));
      }
      if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {
        // filter out exprs whose scope does not matter
        Expr expr = node_to_expr[n];
        if (!expr.as<OpNode>()) {
          lifted_exprs.insert(expr);
        }
      }
    }
    if (n->new_scope) {
      auto child_scope = std::make_shared<ScopeNode>(s);
      expr_scope.insert({n, child_scope});
    } else {
      expr_scope.insert({n, s});
    }
  }
  ICHECK(global_scope_used);
  return std::make_pair(expr_scope, lifted_exprs);
}
98
99namespace {
100
101/* Special care is needed to handle local recursion.
 * Fill additionally takes a (possibly null) Var argument;
103 * If it is not null, Fill is required to bind the transformed result to that var.
104 *
105 * ToANormalForm and PlanDevices
106 * -----------------------------
107 * If PlanDevices has run this transform must respect the lexical scoping rules for the residual
108 * "on_device" calls. Eg:
109 * \code
110 * on_device(add(subtract(x, y), add(y, z)), device_type=2, is_fixed=true)
111 * ==>
112 * let %x0 = on_device(subtract(x, y), device_type=2, is_fixed=true)
113 * let %x1 = on_device(add(y, z), device_type=2, is_fixed=true)
114 * let %x2 = on_device(add(%x0, %x1), device_type=2, is_fixed=true)
115 * %x2
116 * \endcode
117 *
118 * In addition to conversion to ANF this pass is also handling hoisting implicitly shared
119 * sub-expressions to the inner-most scope common to all their uses:
120 * \code
121 * on_device(
122 * if y {
123 * on_device(%0, device_type=2, is_fixed=true)
124 * } else {
125 * on_device(subtract(%0, b), device_type=2, is_fixed=true)
126 * },
127 * device_type=1, is_fixed=true)
128 * (where %0 = add(a, b))
129 * ==>
130 * let %x0 = on_device(add(a, b), device_type=2, is_fixed=true);
131 * on_device(
132 * if y {
133 * on_device(%x0, device_type=2, is_fixed=true)
134 * } else {
135 * let %x1 = on_device(subtract(%x0, b), device_type=2, is_fixed=true);
136 * %x1
137 * },
138 * device_type=1, is_fixed=true)
139 * \endcode
140 * Though the PlanDevices has already avoided inserting "on_device" calls where they are redundant
141 * due to lexical scope, it's fiddly to do the same in this pass since the notion of 'scope' is
142 * now determined by the scope map. So we'll just insert them mechanically on every let-binding.
143 *
144 * TODO(mbs): Rewrite to derive from DeviceAwareExprMutator and not track device types
145 * explicitly. It's easy to get rid of the need for the extra var argument on VisitExpr by shifting
 * the recursion a '1/2 step' to return a possibly compound expression whose inner expressions are
147 * all atomic. However the use of the scope map is currently subtle enough I want to leave it
148 * alone for now.
149 */
/*!
 * \brief Visitor which fills each scope's LetList with let-bindings and
 * assembles the ANF (or basic-block normal form) result.
 *
 * The extra Var argument on VisitExpr supports local recursion: when it is
 * defined, the transformed expression must be bound to that var (see the
 * comment block above).
 */
class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::LexicalOnDeviceMixin {
 public:
  /*! \brief Converts \p e to A-normal form using the scopes in \p node_scope. */
  static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) {
    Fill fi(dg, node_scope, nullptr);
    return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
  }

  // For basic block normal form, bind expressions only if the original expression's scope
  // should be lifted
  static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
                                     NodeScopeMap* node_scope, ExprSet* lifted) {
    Fill fi(dg, node_scope, lifted);
    return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
  }

 private:
  // Note: Conversion to ANF needn't care about the devices for global vars since all that can
  // happen with them is to go from:
  //   ...@g...
  // to:
  //   let %x = @g;
  //   ...
  //   ...%x...
  // In that case the code will ask for the device for @g, get kInvalidDeviceType, then
  // MaybeOnDevice @g, which is always a no-op.
  Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set)
      : transform::LexicalOnDeviceMixin(Optional<IRModule>()),
        dg_(dg),
        node_scope_(node_scope),
        include_set_(include_set) {}

  // Returns the scope CalcScope assigned to the dependency-graph node of \p e.
  Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); }

  // Returns the scope of the i-th child edge of \p e's node, i.e. the
  // sub-scope introduced for constructs like if-branches and function bodies.
  Scope GetSubScope(const Expr& e, size_t i) {
    DependencyGraph::Node* n = dg_.expr_node.at(e);
    auto h = n->children.head;
    while (i != 0) {
      ICHECK(h);
      --i;
      h = h->next;
    }
    ICHECK(h);
    return node_scope_->at(h->value);
  }

  // Convenience overload for sub-expressions with no pending binding var.
  Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }

  Expr VisitExpr(const Expr& e, const Var& v) final {
    if (memo.count(e) == 0) {
      // First visit: transform and memoize (binding to v happens inside the
      // dispatched visitor via Atomic/Compound).
      memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
    } else if (v.defined()) {
      // Already transformed, but a binding to v was requested: emit it now.
      GetScope(e)->let_list->Push(v, memo.at(e));
    }
    auto ret = memo.at(e);
    // if no include_set is specified, every expression should be atomic.
    // TODO(mbs): Note that Constants must be let-bound even though they are considered 'atomic'
    // by this test.
    if (include_set_ == nullptr && function_nesting() > 0) {
      ICHECK(IsAtomic(ret)) << "expression:" << std::endl << PrettyPrint(ret);
    }
    return ret;
  }

  // Annotate an already-atomic expression with its device (if any) and bind it
  // to \p v when \p v is defined; otherwise return it directly.
  Expr Atomic(const Expr& e, const Var& v) {
    Expr annotated_expr = MaybeOnDeviceFixed(e, GetVirtualDevice(e));
    return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr;
  }

  // Bind expression `now` to var `v` if the original expression is in the include set, or if
  // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly
  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
    Expr annotated_expr = MaybeOnDeviceFixed(now, GetVirtualDevice(orig));
    Var var = v.defined() ? v : Var::GenSym();
    bool not_included = include_set_ && include_set_->find(orig) == include_set_->end();
    if (!v.defined() && not_included) {
      return annotated_expr;
    } else if (const LetNode* let = AsIgnoringOnDevice<LetNode>(now)) {
      // Instead of making a nested binding "let var = (let x = ...; bindings...; body)", we push
      // the inner bindings into the outer scope and bind body to var, giving
      // "let x = ...; bindings...; let var = body;" as the resulting bindings.
      Expr e = GetRef<Expr>(let);
      while (const LetNode* inner_let = AsIgnoringOnDevice<LetNode>(e)) {
        GetScope(orig)->let_list->Push(inner_let->var, inner_let->value);
        e = inner_let->body;
      }
      Expr annotated_body = MaybeOnDeviceFixed(e, GetVirtualDevice(orig));
      return GetScope(orig)->let_list->Push(var, annotated_body);
    } else {
      return GetScope(orig)->let_list->Push(var, annotated_expr);
    }
  }

  Expr VisitExpr_(const CallNode* c, const Var& v) final {
    OnDeviceProps props = GetOnDeviceProps(c);
    if (props.body.defined() && props.is_fixed()) {
      // Keep track of expression device type for lexically enclosing sub-expressions.
      PushVirtualDevice(props.virtual_device);
      Expr body = VisitExpr(props.body, v);
      // We are done with this sub-expression.
      PopVirtualDevice();
      // Preserve the "on_device" annotations.
      return OnDeviceWithProps(body, props);
    }

    // Ordinary call: normalize operator and arguments, then let-bind the call.
    Expr e = GetRef<Expr>(c);
    std::vector<Expr> args;
    for (const auto& a : c->args) {
      args.push_back(VisitExpr(a));
    }
    return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
  }

  Expr VisitExpr_(const TupleNode* tuple_node, const Var& v) final {
    Expr e = GetRef<Expr>(tuple_node);
    Array<Expr> fields;
    fields.reserve(tuple_node->fields.size());
    for (const auto& a : tuple_node->fields) {
      fields.push_back(VisitExpr(a));
    }
    return Compound(e, WithFields(GetRef<Tuple>(tuple_node), fields), v);
  }

  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v);
  }

  Expr VisitExpr_(const RefCreateNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefCreate(VisitExpr(r->value)), v);
  }

  Expr VisitExpr_(const RefReadNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefRead(VisitExpr(r->ref)), v);
  }

  Expr VisitExpr_(const RefWriteNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v);
  }

  Expr VisitExpr_(const IfNode* i, const Var& v) final {
    Expr e = GetRef<Expr>(i);
    // Each branch is normalized within its own sub-scope (child edges 1 and 2;
    // edge 0 is the condition).
    Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)),
                  GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch)));
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
    Expr e = GetRef<Expr>(f);
    Expr ret;
    if (f->HasNonzeroAttr(attr::kPrimitive)) {
      // Primitive functions are opaque to this pass and kept as-is.
      ret = e;
    } else {
      // Keep track of expression and bound variable device types for lexically enclosing
      // sub-expressions.
      PushVirtualDevice(f->virtual_device());
      for (auto param : f->params) {
        PushBoundVar(param, param->virtual_device());
      }
      EnterFunctionBody();
      // The body is normalized within the function's own sub-scope.
      ret = WithFields(GetRef<Function>(f), f->params,
                       GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)));
      // We are done with this function.
      ExitFunctionBody();
      for (size_t i = 0; i < f->params.size(); ++i) {
        PopBoundVar(f->params[i]);
      }
      PopVirtualDevice();
    }
    if (function_nesting() == 0) {
      ICHECK(!v.defined());
      // This is a global function which can be bound directly in the module.
      return ret;
    } else {
      // This is a local function which must be let-bound.
      return Compound(e, ret, v);
    }
  }

  Expr VisitExpr_(const LetNode* l, const Var& v) final {
    Expr e = GetRef<Expr>(l);
    // Keep track of bound variable device types for lexically enclosing sub-expressions.
    PushBoundVar(l->var, GetVirtualDevice(l->value));
    // Normalize the value, requiring it to be bound to the let's own var.
    VisitExpr(l->value, l->var);
    Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body));
    // We are done with these sub-expressions.
    PopBoundVar(l->var);
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
    // Constants are let-bound via Compound (they are not treated as atomic here).
    Expr e = GetRef<Expr>(c);
    return Compound(e, e, v);
  }

  Expr VisitExpr_(const VarNode* vn, const Var& v) final {
    Expr e = GetRef<Expr>(vn);
    return Atomic(e, v);
  }

  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
    GlobalVar gv = GetRef<GlobalVar>(gvn);
    return Atomic(gv, v);
  }

  Expr VisitExpr_(const OpNode* op, const Var& v) final {
    Expr e = GetRef<Expr>(op);
    return Atomic(e, v);
  }

  Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    return Atomic(e, v);
  }

  Expr VisitExpr_(const MatchNode* m, const Var& v) final {
    Expr e = GetRef<Expr>(m);
    Expr data = VisitExpr(m->data);
    std::vector<Clause> clauses;
    // Each clause body lives in its own sub-scope; child edge 0 is the data.
    for (const Clause& c : m->clauses) {
      clauses.emplace_back(c->lhs,
                           GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs)));
    }
    return Compound(e, Match(data, clauses, m->complete), v);
  }

  // Dependency graph of the expression being normalized.
  const DependencyGraph& dg_;
  // Scope assignment computed by CalcScope.
  NodeScopeMap* node_scope_ = nullptr;
  // Memoized results so shared sub-expressions are transformed exactly once.
  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo;
  // a set of Expressions to include for let bindings. If set to nullptr
  // all Exprs will be pushed to the let list.
  ExprSet* include_set_ = nullptr;
};
385
386IRModule ModuleToANormalForm(const IRModule& mod) {
387 tvm::Map<GlobalVar, Function> updates;
388 auto funcs = mod->functions;
389 for (const auto& it : funcs) {
390 ICHECK_EQ(FreeVars(it.second).size(), 0);
391 if (const auto* n = it.second.as<FunctionNode>()) {
392 if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
393 Function func = GetRef<Function>(n);
394 Function ret = Downcast<Function>(transform::ToANormalForm(func));
395 ICHECK_EQ(FreeVars(ret).size(), 0) << "rewritten:" << std::endl
396 << PrettyPrint(ret) << std::endl
397 << "should not have free vars: " << FreeVars(ret);
398 VLOG(1) << "rewritten:" << std::endl
399 << PrettyPrint(func) << std::endl
400 << "to ANF:" << std::endl
401 << PrettyPrint(ret);
402 updates.Set(it.first, ret);
403 }
404 }
405
406 for (auto pair : updates) {
407 mod->Add(pair.first, pair.second, true);
408 }
409
410 return mod;
411}
412
413} // namespace
414
415Expr ToBasicBlockNormalFormAux(const Expr& e) {
416 // calculate all the dependency between nodes.
417 support::Arena arena;
418 DependencyGraph dg = DependencyGraph::Create(&arena, e);
419 /* The scope of the whole expr is global.
420 * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
421 * We also record the set of expressions whose scope is lifted.
422 */
423 std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
424 return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
425}
426
427namespace transform {
428
429Expr ToANormalForm(const Expr& e) {
430 /* When you lift a lambda, what is inside is also being lift.
431 *
432 * So we must determine the scope of the lambda before determining the scope of it's body.
433 *
434 * To make this more principled,
435 * we always determine the scope of parent before determining the scope of children.
436 *
437 * So we calculate all the dependency between nodes.
438 */
439 support::Arena arena;
440 DependencyGraph dg = DependencyGraph::Create(&arena, e);
441 /* In order to model new subscopes created by lambda, if else and pattern matching,
442 * we also assign scope to edge as well.
443 * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
444 *
445 * So, the scope of the whole expr is global.
446 * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
447 *
448 * Every scope additionally contain a LetList which collect all value of that scope.
449 * We do an additional pass to fill all the LetList and we are done.
450 */
451 std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
452 return Fill::ToANormalForm(e, dg, &scopes.first);
453}
454
455Pass ToANormalForm() {
456 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
457 [=](IRModule m, PassContext pc) { return ModuleToANormalForm(m); };
458 return CreateModulePass(pass_func, 1, "ToANormalForm", {});
459}
460
// Expose the module-level ToANormalForm pass through the FFI registry.
TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() {
  return ToANormalForm();
});

// Expose expression-level ANF conversion through the FFI registry.
TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr").set_body_typed([](const Expr& e) {
  return ToANormalForm(e);
});
468
469} // namespace transform
470
471} // namespace relay
472} // namespace tvm
473