1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file to_a_normal_form.cc |
23 | * |
24 | * \brief Turn implicit sharing into observable sharing. |
25 | */ |
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/transform.h> |
29 | #include <tvm/runtime/logging.h> |
30 | |
31 | #include "../../support/arena.h" |
32 | #include "../analysis/dependency_graph.h" |
33 | #include "../op/annotation/annotation.h" |
34 | #include "./device_aware_visitors.h" |
35 | #include "./let_list.h" |
36 | #include "./pass_utils.h" |
37 | |
38 | namespace tvm { |
39 | namespace relay { |
40 | |
41 | Scope LCA(Scope lhs, Scope rhs) { |
42 | while (lhs != rhs) { |
43 | if (lhs->level > rhs->level) { |
44 | lhs = lhs->parent; |
45 | } else if (lhs->level < rhs->level) { |
46 | rhs = rhs->parent; |
47 | } else { |
48 | lhs = lhs->parent; |
49 | rhs = rhs->parent; |
50 | } |
51 | } |
52 | return lhs; |
53 | } |
54 | |
55 | std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg) { |
56 | NodeScopeMap expr_scope; |
57 | ExprSet lifted_exprs; |
58 | std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr; |
59 | for (auto expr_node : dg.expr_node) { |
60 | node_to_expr[expr_node.second] = expr_node.first; |
61 | } |
62 | bool global_scope_used = false; |
63 | Scope global_scope = std::make_shared<ScopeNode>(); |
64 | |
65 | for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { |
66 | DependencyGraph::Node* n = *it; |
67 | auto iit = n->parents.head; |
68 | Scope s; |
69 | if (iit == nullptr) { |
70 | ICHECK(!global_scope_used); |
71 | s = global_scope; |
72 | global_scope_used = true; |
73 | } else { |
74 | s = expr_scope.at(iit->value); |
75 | const auto original_s = s; |
76 | iit = iit->next; |
77 | for (; iit != nullptr; iit = iit->next) { |
78 | s = LCA(s, expr_scope.at(iit->value)); |
79 | } |
80 | if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) { |
81 | // filter out exprs whose scope do not matter |
82 | Expr expr = node_to_expr[n]; |
83 | if (!expr.as<OpNode>()) { |
84 | lifted_exprs.insert(expr); |
85 | } |
86 | } |
87 | } |
88 | if (n->new_scope) { |
89 | auto child_scope = std::make_shared<ScopeNode>(s); |
90 | expr_scope.insert({n, child_scope}); |
91 | } else { |
92 | expr_scope.insert({n, s}); |
93 | } |
94 | } |
95 | ICHECK(global_scope_used); |
96 | return std::make_pair(expr_scope, lifted_exprs); |
97 | } |
98 | |
99 | namespace { |
100 | |
/* Special care is needed to handle local recursion.
 * Fill additionally takes a (possibly null) Var argument;
 * if it is not null, Fill is required to bind the transformed result to that var.
104 | * |
105 | * ToANormalForm and PlanDevices |
106 | * ----------------------------- |
107 | * If PlanDevices has run this transform must respect the lexical scoping rules for the residual |
108 | * "on_device" calls. Eg: |
109 | * \code |
110 | * on_device(add(subtract(x, y), add(y, z)), device_type=2, is_fixed=true) |
111 | * ==> |
112 | * let %x0 = on_device(subtract(x, y), device_type=2, is_fixed=true) |
113 | * let %x1 = on_device(add(y, z), device_type=2, is_fixed=true) |
114 | * let %x2 = on_device(add(%x0, %x1), device_type=2, is_fixed=true) |
115 | * %x2 |
116 | * \endcode |
117 | * |
118 | * In addition to conversion to ANF this pass is also handling hoisting implicitly shared |
119 | * sub-expressions to the inner-most scope common to all their uses: |
120 | * \code |
121 | * on_device( |
122 | * if y { |
123 | * on_device(%0, device_type=2, is_fixed=true) |
124 | * } else { |
125 | * on_device(subtract(%0, b), device_type=2, is_fixed=true) |
126 | * }, |
127 | * device_type=1, is_fixed=true) |
128 | * (where %0 = add(a, b)) |
129 | * ==> |
130 | * let %x0 = on_device(add(a, b), device_type=2, is_fixed=true); |
131 | * on_device( |
132 | * if y { |
133 | * on_device(%x0, device_type=2, is_fixed=true) |
134 | * } else { |
135 | * let %x1 = on_device(subtract(%x0, b), device_type=2, is_fixed=true); |
136 | * %x1 |
137 | * }, |
138 | * device_type=1, is_fixed=true) |
139 | * \endcode |
140 | * Though the PlanDevices has already avoided inserting "on_device" calls where they are redundant |
141 | * due to lexical scope, it's fiddly to do the same in this pass since the notion of 'scope' is |
142 | * now determined by the scope map. So we'll just insert them mechanically on every let-binding. |
143 | * |
 * TODO(mbs): Rewrite to derive from DeviceAwareExprMutator and not track device types
 * explicitly. It's easy to get rid of the need for the extra var argument on VisitExpr by shifting
 * the recursion a '1/2 step' to return a possibly compound expression whose inner expressions are
 * all atomic. However the use of the scope map is currently subtle enough that I want to leave it
 * alone for now.
149 | */ |
class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::LexicalOnDeviceMixin {
 public:
  /*!
   * \brief Converts \p e to A-normal form: each sub-expression is let-bound in the
   * scope assigned to its dependency-graph node by \p node_scope, and the result is
   * wrapped in the accumulated bindings of \p e's own scope.
   */
  static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope) {
    Fill fi(dg, node_scope, nullptr);
    return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
  }

  // For basic block normal form, bind expressions only if the original expression's scope
  // should be lifted
  static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
                                     NodeScopeMap* node_scope, ExprSet* lifted) {
    Fill fi(dg, node_scope, lifted);
    return fi.GetScope(e)->let_list->Get(fi.VisitExpr(e));
  }

 private:
  // Note: Conversion to ANF needn't care about the devices for global vars since all that can
  // happen with them is to go from:
  //   ...@g...
  // to:
  //   let %x = @g;
  //   ...
  //   ...%x...
  // In that case the code will ask for the device for @g, get kInvalidDeviceType, then
  // MaybeOnDevice @g, which is always a no-op.
  Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set)
      : transform::LexicalOnDeviceMixin(Optional<IRModule>()),
        dg_(dg),
        node_scope_(node_scope),
        include_set_(include_set) {}

  // Returns the scope the pre-computed scope map assigned to \p e's dependency-graph node.
  Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); }

  // Returns the scope of the \p i-th child edge of \p e's dependency-graph node.
  // Used where sub-scopes were opened: function bodies, if branches, match arms.
  Scope GetSubScope(const Expr& e, size_t i) {
    DependencyGraph::Node* n = dg_.expr_node.at(e);
    auto h = n->children.head;
    while (i != 0) {
      ICHECK(h);
      --i;
      h = h->next;
    }
    ICHECK(h);
    return node_scope_->at(h->value);
  }

  // Visit without requesting the result be bound to any particular variable.
  Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); }

  // Memoizing dispatch: each expression is transformed at most once. If \p e was
  // already visited and \p v is defined, the cached result is re-bound to \p v in
  // e's scope — this is what makes implicit sharing observable as a let-binding.
  Expr VisitExpr(const Expr& e, const Var& v) final {
    if (memo.count(e) == 0) {
      memo.insert({e, ExprFunctor<Expr(const Expr&, const Var&)>::VisitExpr(e, v)});
    } else if (v.defined()) {
      GetScope(e)->let_list->Push(v, memo.at(e));
    }
    auto ret = memo.at(e);
    // if no include_set is specified, every expression should be atomic.
    // TODO(mbs): Note that Constants must be let-bound even though they are considered 'atomic'
    // by this test.
    if (include_set_ == nullptr && function_nesting() > 0) {
      ICHECK(IsAtomic(ret)) << "expression:" << std::endl << PrettyPrint(ret);
    }
    return ret;
  }

  // Atomic expressions (vars, global vars, ops, constructors) need no binding of
  // their own: annotate with the device, and bind to \p v only if a binding was
  // explicitly requested (e.g. by an enclosing Let).
  Expr Atomic(const Expr& e, const Var& v) {
    Expr annotated_expr = MaybeOnDeviceFixed(e, GetVirtualDevice(e));
    return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr;
  }

  // Bind expression `now` to var `v` if the original expression is in the include set, or if
  // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly
  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
    Expr annotated_expr = MaybeOnDeviceFixed(now, GetVirtualDevice(orig));
    Var var = v.defined() ? v : Var::GenSym();
    bool not_included = include_set_ && include_set_->find(orig) == include_set_->end();
    if (!v.defined() && not_included) {
      return annotated_expr;
    } else if (const LetNode* let = AsIgnoringOnDevice<LetNode>(now)) {
      // Instead of making a nested binding "let var = (let x = ...; bindings...; body)", we push
      // the inner bindings into the outer scope and bind body to var, giving
      // "let x = ...; bindings...; let var = body;" as the resulting bindings.
      Expr e = GetRef<Expr>(let);
      while (const LetNode* inner_let = AsIgnoringOnDevice<LetNode>(e)) {
        GetScope(orig)->let_list->Push(inner_let->var, inner_let->value);
        e = inner_let->body;
      }
      Expr annotated_body = MaybeOnDeviceFixed(e, GetVirtualDevice(orig));
      return GetScope(orig)->let_list->Push(var, annotated_body);
    } else {
      return GetScope(orig)->let_list->Push(var, annotated_expr);
    }
  }

  Expr VisitExpr_(const CallNode* c, const Var& v) final {
    // Fixed "on_device" annotations are peeled off, the body is converted with the
    // device in lexical scope, and the annotation is re-applied around the result.
    OnDeviceProps props = GetOnDeviceProps(c);
    if (props.body.defined() && props.is_fixed()) {
      // Keep track of expression device type for lexically enclosing sub-expressions.
      PushVirtualDevice(props.virtual_device);
      Expr body = VisitExpr(props.body, v);
      // We are done with this sub-expression.
      PopVirtualDevice();
      // Preserve the "on_device" annotations.
      return OnDeviceWithProps(body, props);
    }

    Expr e = GetRef<Expr>(c);
    std::vector<Expr> args;
    for (const auto& a : c->args) {
      args.push_back(VisitExpr(a));
    }
    return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
  }

  Expr VisitExpr_(const TupleNode* tuple_node, const Var& v) final {
    Expr e = GetRef<Expr>(tuple_node);
    Array<Expr> fields;
    fields.reserve(tuple_node->fields.size());
    for (const auto& a : tuple_node->fields) {
      fields.push_back(VisitExpr(a));
    }
    return Compound(e, WithFields(GetRef<Tuple>(tuple_node), fields), v);
  }

  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
    Expr e = GetRef<Expr>(t);
    return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v);
  }

  Expr VisitExpr_(const RefCreateNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefCreate(VisitExpr(r->value)), v);
  }

  Expr VisitExpr_(const RefReadNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefRead(VisitExpr(r->ref)), v);
  }

  Expr VisitExpr_(const RefWriteNode* r, const Var& v) final {
    Expr e = GetRef<Expr>(r);
    return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v);
  }

  Expr VisitExpr_(const IfNode* i, const Var& v) final {
    // Each branch is wrapped in the accumulated bindings of its own sub-scope
    // (child edges 1 and 2 of the If node).
    Expr e = GetRef<Expr>(i);
    Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->let_list->Get(VisitExpr(i->true_branch)),
                  GetSubScope(e, 2)->let_list->Get(VisitExpr(i->false_branch)));
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
    // Primitive functions are left untouched; otherwise the body is converted
    // inside the function's own sub-scope (child edge 0).
    Expr e = GetRef<Expr>(f);
    Expr ret;
    if (f->HasNonzeroAttr(attr::kPrimitive)) {
      ret = e;
    } else {
      // Keep track of expression and bound variable device types for lexically enclosing
      // sub-expressions.
      PushVirtualDevice(f->virtual_device());
      for (auto param : f->params) {
        PushBoundVar(param, param->virtual_device());
      }
      EnterFunctionBody();
      ret = WithFields(GetRef<Function>(f), f->params,
                       GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)));
      // We are done with this function.
      ExitFunctionBody();
      for (size_t i = 0; i < f->params.size(); ++i) {
        PopBoundVar(f->params[i]);
      }
      PopVirtualDevice();
    }
    if (function_nesting() == 0) {
      ICHECK(!v.defined());
      // This is a global function which can be bound directly in the module.
      return ret;
    } else {
      // This is a local function which must be let-bound.
      return Compound(e, ret, v);
    }
  }

  Expr VisitExpr_(const LetNode* l, const Var& v) final {
    Expr e = GetRef<Expr>(l);
    // Keep track of bound variable device types for lexically enclosing sub-expressions.
    PushBoundVar(l->var, GetVirtualDevice(l->value));
    // Visiting the value with l->var forces the binding to be emitted under the
    // original variable name; the body lives in the Let's sub-scope (child edge 0).
    VisitExpr(l->value, l->var);
    Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body));
    // We are done with these sub-expressions.
    PopBoundVar(l->var);
    return Compound(e, ret, v);
  }

  Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
    // Constants go through Compound (not Atomic) so that in ANF they end up
    // let-bound (cf. the TODO in VisitExpr).
    Expr e = GetRef<Expr>(c);
    return Compound(e, e, v);
  }

  Expr VisitExpr_(const VarNode* vn, const Var& v) final {
    Expr e = GetRef<Expr>(vn);
    return Atomic(e, v);
  }

  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
    GlobalVar gv = GetRef<GlobalVar>(gvn);
    return Atomic(gv, v);
  }

  Expr VisitExpr_(const OpNode* op, const Var& v) final {
    Expr e = GetRef<Expr>(op);
    return Atomic(e, v);
  }

  Expr VisitExpr_(const ConstructorNode* c, const Var& v) final {
    Expr e = GetRef<Expr>(c);
    return Atomic(e, v);
  }

  Expr VisitExpr_(const MatchNode* m, const Var& v) final {
    // Each clause's rhs is wrapped in the bindings of its own sub-scope
    // (child edges 1, 2, ... of the Match node).
    Expr e = GetRef<Expr>(m);
    Expr data = VisitExpr(m->data);
    std::vector<Clause> clauses;
    for (const Clause& c : m->clauses) {
      clauses.emplace_back(c->lhs,
                           GetSubScope(e, 1 + clauses.size())->let_list->Get(VisitExpr(c->rhs)));
    }
    return Compound(e, Match(data, clauses, m->complete), v);
  }

  const DependencyGraph& dg_;
  NodeScopeMap* node_scope_ = nullptr;
  // Cache of already-transformed expressions (see the memoizing VisitExpr).
  std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo;
  // a set of Expressions to include for let bindings. If set to nullptr
  // all Exprs will be pushed to the let list.
  ExprSet* include_set_ = nullptr;
};
385 | |
386 | IRModule ModuleToANormalForm(const IRModule& mod) { |
387 | tvm::Map<GlobalVar, Function> updates; |
388 | auto funcs = mod->functions; |
389 | for (const auto& it : funcs) { |
390 | ICHECK_EQ(FreeVars(it.second).size(), 0); |
391 | if (const auto* n = it.second.as<FunctionNode>()) { |
392 | if (n->GetAttr<String>(attr::kCompiler).defined()) continue; |
393 | Function func = GetRef<Function>(n); |
394 | Function ret = Downcast<Function>(transform::ToANormalForm(func)); |
395 | ICHECK_EQ(FreeVars(ret).size(), 0) << "rewritten:" << std::endl |
396 | << PrettyPrint(ret) << std::endl |
397 | << "should not have free vars: " << FreeVars(ret); |
398 | VLOG(1) << "rewritten:" << std::endl |
399 | << PrettyPrint(func) << std::endl |
400 | << "to ANF:" << std::endl |
401 | << PrettyPrint(ret); |
402 | updates.Set(it.first, ret); |
403 | } |
404 | } |
405 | |
406 | for (auto pair : updates) { |
407 | mod->Add(pair.first, pair.second, true); |
408 | } |
409 | |
410 | return mod; |
411 | } |
412 | |
413 | } // namespace |
414 | |
415 | Expr ToBasicBlockNormalFormAux(const Expr& e) { |
416 | // calculate all the dependency between nodes. |
417 | support::Arena arena; |
418 | DependencyGraph dg = DependencyGraph::Create(&arena, e); |
419 | /* The scope of the whole expr is global. |
420 | * The scope of any subexpr, is the lowest common ancestor of all incoming edge. |
421 | * We also record the set of expressions whose scope is lifted. |
422 | */ |
423 | std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg); |
424 | return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second); |
425 | } |
426 | |
427 | namespace transform { |
428 | |
429 | Expr ToANormalForm(const Expr& e) { |
430 | /* When you lift a lambda, what is inside is also being lift. |
431 | * |
432 | * So we must determine the scope of the lambda before determining the scope of it's body. |
433 | * |
434 | * To make this more principled, |
435 | * we always determine the scope of parent before determining the scope of children. |
436 | * |
437 | * So we calculate all the dependency between nodes. |
438 | */ |
439 | support::Arena arena; |
440 | DependencyGraph dg = DependencyGraph::Create(&arena, e); |
441 | /* In order to model new subscopes created by lambda, if else and pattern matching, |
442 | * we also assign scope to edge as well. |
443 | * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope. |
444 | * |
445 | * So, the scope of the whole expr is global. |
446 | * The scope of any subexpr, is the lowest common ancestor of all incoming edge. |
447 | * |
448 | * Every scope additionally contain a LetList which collect all value of that scope. |
449 | * We do an additional pass to fill all the LetList and we are done. |
450 | */ |
451 | std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg); |
452 | return Fill::ToANormalForm(e, dg, &scopes.first); |
453 | } |
454 | |
455 | Pass ToANormalForm() { |
456 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = |
457 | [=](IRModule m, PassContext pc) { return ModuleToANormalForm(m); }; |
458 | return CreateModulePass(pass_func, 1, "ToANormalForm" , {}); |
459 | } |
460 | |
// Register the module-level pass constructor in the global registry (invoked via the FFI).
TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm" ).set_body_typed([]() {
  return ToANormalForm();
});

// Register the expression-level conversion for normalizing a single Expr.
TVM_REGISTER_GLOBAL("relay._transform.ToANormalFormExpr" ).set_body_typed([](const Expr& e) {
  return ToANormalForm(e);
});
468 | |
469 | } // namespace transform |
470 | |
471 | } // namespace relay |
472 | } // namespace tvm |
473 | |