1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/vm/manifest_lifetimes.cc |
22 | * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in |
23 | * ANF and post-memory-lowering (explicit manifestation of allocations). |
24 | */ |
25 | |
26 | #include <tvm/relay/transform.h> |
27 | |
28 | #include "../../../support/arena.h" |
29 | #include "../../op/memory/device_copy.h" |
30 | #include "../../transforms/device_aware_visitors.h" |
31 | #include "../../transforms/let_list.h" |
32 | #include "../liveness_analysis.h" |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | namespace transform { |
37 | |
38 | /*! |
39 | * \brief Helper class to insert kills using liveness information. |
40 | */ |
41 | class KillInserter : public ExprMutator { |
42 | public: |
43 | KillInserter(const ControlFlowGraph* cfg, const LivenessAnalysis* lva) : cfg_(cfg), lva_(lva) {} |
44 | |
45 | // Limitations |
46 | // ----------- |
47 | // (1) For simplicity, we only insert kills when visiting Let bindings, and always emit the kill |
48 | // as a single subsequent binding. This is slightly inaccurate; for example, if the condition of |
49 | // an If is dead after the test, we can immediately kill the condition in each branch: |
50 | // let %x = if (%dead_cond) { |
51 | // let %_0 = memory.kill(%dead_cond); |
52 | // ... |
53 | // } else { |
54 | // let %_1 = memory.kill(%dead_cond); |
55 | // ... |
56 | // } |
57 | // as opposed to: |
58 | // let %x = if (%dead_cond) ... |
59 | // let %_0 = memory.kill(%dead_cond); |
60 | // |
61 | // (2) Killed variables are calculated as live in - live out, which misses variables that are |
62 | // actually dead but not in a live-in set. Example: |
63 | // @f(%x: int, %y: int, %c: bool) { |
64 | // let %w = if (%c) { |
65 | // let %z = %y + %y; |
66 | // %z |
67 | // } else { |
68 | // %y |
69 | // }; |
70 | // %w |
71 | // } |
72 | // After inserting kills: |
73 | // @f(%x: int, %y: int, %c: bool) { |
74 | // /* %x is always dead, so never in any live in or live out set */ |
75 | // let %w = if (%c) { |
76 | // let %z = %y + %y; |
77 | // let %_0 = memory.kill(%y); |
78 | // %z |
79 | // } else { |
80 | // %y |
81 | // /* %y is dead at this point */ |
82 | // }; |
83 | // let %_1 = memory.kill(%c); |
84 | // /* no kill for %y since it's not in the live-in of %w AND %w isn't a let binding */ |
85 | // %w |
86 | // } |
87 | // |
88 | // (3) When the result expr of an If branch is a variable, and this expr is the last use of the |
89 | // var, we cannot "kill" the var since it is being returned. The VM compiler also emits a Move |
90 | // instruction to merge the branch results, which creates another ObjectRef to the Object held |
91 | // by the var. The var is also not in the subsequent live-in (since it is indeed dead by this |
92 | // point), so it won't be killed. An example can be seen in the previous code block for (2), where |
93 | // %y is not killed if the else-branch is taken (and indeed it can be killed, as %w is mapped to |
94 | // a new register and holds a fresh reference to the object referenced by %y). |
95 | // |
96 | // However, these limitations are unlikely to cause large leaks in practice. |
97 | |
98 | Expr VisitExpr_(const LetNode* let_node) override { |
99 | Expr expr = GetRef<Expr>(let_node); |
100 | LetList ll; |
101 | |
102 | while (const LetNode* inner_let_node = expr.as<LetNode>()) { |
103 | ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value)); |
104 | |
105 | ICHECK(!inner_let_node->value.as<VarNode>()) << "aliasing should have been eliminated." ; |
106 | ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG" ; |
107 | |
108 | const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr); |
109 | |
110 | const VarSet& li = lva_->live_in.at(n); |
111 | const VarSet& lo = lva_->live_out.at(n); |
112 | |
113 | // Killed vars = live in - live out. |
114 | VarSet kills; |
115 | for (const Var& v : li) { |
116 | if (!lo.count(v)) { |
117 | kills.insert(v); |
118 | } |
119 | } |
120 | |
121 | for (const Var& v : kills) { |
122 | ll.Push(Call(Op::Get("memory.kill" ), {v})); |
123 | } |
124 | |
125 | expr = inner_let_node->body; |
126 | } |
127 | |
128 | return ll.Get(VisitExpr(expr)); |
129 | } |
130 | |
131 | private: |
132 | const ControlFlowGraph* cfg_; |
133 | const LivenessAnalysis* lva_; |
134 | }; |
135 | |
136 | /*! |
137 | * \brief Helper class to eliminate variable aliasing. This pass anticipates the VM compiler's |
138 | * register aliasing behavior so as to avoid killing vars that point to the same register. An |
139 | * alternative approach would be to track aliasing within the VM compiler itself, so that kill |
140 | * instructions are only emitted when all aliases are killed. |
141 | */ |
142 | class AliasEliminator : public MixedModeMutator { |
143 | public: |
144 | using MixedModeMutator::VisitExpr_; |
145 | |
146 | Expr VisitExpr_(const LetNode* let_node) override { |
147 | Expr expr = GetRef<Expr>(let_node); |
148 | LetList ll; |
149 | std::vector<Var> aliased_vars; |
150 | |
151 | while (const LetNode* inner_let_node = expr.as<LetNode>()) { |
152 | const Var& var = inner_let_node->var; |
153 | const Expr& val = inner_let_node->value; |
154 | bool aliased = false; |
155 | ICHECK(!alias_.count(var)); |
156 | |
157 | if (const VarNode* alias_of_n = AsIgnoringOnDevice<VarNode>(val)) { |
158 | alias_[var] = Downcast<Var>(VisitExpr_(alias_of_n)); |
159 | aliased = true; |
160 | } else if (AsIgnoringOnDevice<CallNode>(val)) { |
161 | // Copying to the same device is aliasing. |
162 | // WARNING: this must be kept in sync with the VM compiler logic in |
163 | // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*). |
164 | Expr unwrapped = IgnoreOnDevice(val); |
165 | DeviceCopyProps copy_props = GetDeviceCopyProps(unwrapped); |
166 | if (copy_props.body.defined()) { |
167 | if (copy_props.src_virtual_device->device_type() == |
168 | copy_props.dst_virtual_device->device_type() && |
169 | copy_props.src_virtual_device->virtual_device_id == |
170 | copy_props.dst_virtual_device->virtual_device_id) { |
171 | Expr to_copy = Downcast<Call>(unwrapped)->args[0]; |
172 | if (const VarNode* alias_of_n = to_copy.as<VarNode>()) { |
173 | alias_[var] = Downcast<Var>(VisitExpr_(alias_of_n)); |
174 | aliased = true; |
175 | } |
176 | } |
177 | } |
178 | } |
179 | |
180 | if (!aliased) { |
181 | ll.Push(var, VisitExpr(val)); |
182 | } else { |
183 | aliased_vars.push_back(var); |
184 | } |
185 | |
186 | expr = inner_let_node->body; |
187 | } |
188 | |
189 | Expr body = ll.Get(VisitExpr(expr)); |
190 | |
191 | // remove the aliased vars so that alias_ only tracks things in scope |
192 | for (const Var& v : aliased_vars) { |
193 | alias_.erase(v); |
194 | } |
195 | |
196 | return body; |
197 | } |
198 | |
199 | Expr VisitExpr_(const VarNode* var_node) override { |
200 | Var var = GetRef<Var>(var_node); |
201 | if (alias_.count(var)) { |
202 | return alias_[var]; |
203 | } |
204 | return std::move(var); |
205 | } |
206 | |
207 | Expr VisitExpr_(const FunctionNode* func_node) override { |
208 | Expr new_body = VisitExpr(func_node->body); |
209 | return WithFields(GetRef<Function>(func_node), /*opt_params=*/NullOpt, /*opt_body=*/new_body); |
210 | } |
211 | |
212 | // The only register-level aliasing that occurs in Match expressions is when |
213 | // the deconstructed expression is a Var, and the matched pattern is also a Var. |
214 | Expr VisitExpr_(const MatchNode* match_node) override { |
215 | if (const VarNode* data_var_node = AsIgnoringOnDevice<VarNode>(match_node->data)) { |
216 | Var data_var = Downcast<Var>(VisitExpr_(data_var_node)); |
217 | std::vector<Clause> new_clauses; |
218 | for (const Clause& clause : match_node->clauses) { |
219 | const PatternVarNode* pv_node = nullptr; |
220 | if ((pv_node = clause->lhs.as<PatternVarNode>())) { |
221 | alias_[pv_node->var] = data_var; |
222 | } |
223 | new_clauses.push_back(Clause(clause->lhs, VisitExpr(clause->rhs))); |
224 | if (pv_node) { |
225 | alias_.erase(pv_node->var); |
226 | } |
227 | } |
228 | return Match(data_var, new_clauses, match_node->complete, match_node->span); |
229 | } else { |
230 | return ExprMutator::VisitExpr_(match_node); |
231 | } |
232 | } |
233 | |
234 | private: |
235 | /*! |
236 | * \brief Mapping of var -> var it's an alias of. Note that transitive aliases |
237 | * (e.g. x = 0; y = x; z = y) are mapped to the non-aliased variable (in this example "x"). |
238 | */ |
239 | std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> alias_; |
240 | }; |
241 | |
242 | Pass ManifestLifetimes() { |
243 | auto pass_func = [](Function f, IRModule m, PassContext pc) -> Function { |
244 | f = Downcast<Function>(AliasEliminator().Mutate(f)); |
245 | Arena arena; |
246 | ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, f); |
247 | UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); |
248 | LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); |
249 | KillInserter ki(&cfg, &lva); |
250 | Function nf = Downcast<Function>(ki.Mutate(f)); |
251 | return nf; |
252 | }; |
253 | return CreateFunctionPass(pass_func, 0, "ManifestLifetimes" , {}); |
254 | } |
255 | |
256 | TVM_REGISTER_GLOBAL("relay._transform.ManifestLifetimes" ).set_body_typed(ManifestLifetimes); |
257 | |
258 | } // namespace transform |
259 | } // namespace relay |
260 | } // namespace tvm |
261 | |