1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/backend/vm/manifest_lifetimes.cc
22 * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in
23 * ANF and post-memory-lowering (explicit manifestation of allocations).
24 */
25
26#include <tvm/relay/transform.h>
27
28#include "../../../support/arena.h"
29#include "../../op/memory/device_copy.h"
30#include "../../transforms/device_aware_visitors.h"
31#include "../../transforms/let_list.h"
32#include "../liveness_analysis.h"
33
34namespace tvm {
35namespace relay {
36namespace transform {
37
38/*!
39 * \brief Helper class to insert kills using liveness information.
40 */
41class KillInserter : public ExprMutator {
42 public:
43 KillInserter(const ControlFlowGraph* cfg, const LivenessAnalysis* lva) : cfg_(cfg), lva_(lva) {}
44
45 // Limitations
46 // -----------
47 // (1) For simplicity, we only insert kills when visiting Let bindings, and always emit the kill
48 // as a single subsequent binding. This is slightly inaccurate; for example, if the condition of
49 // an If is dead after the test, we can immediately kill the condition in each branch:
50 // let %x = if (%dead_cond) {
51 // let %_0 = memory.kill(%dead_cond);
52 // ...
53 // } else {
54 // let %_1 = memory.kill(%dead_cond);
55 // ...
56 // }
57 // as opposed to:
58 // let %x = if (%dead_cond) ...
59 // let %_0 = memory.kill(%dead_cond);
60 //
61 // (2) Killed variables are calculated as live in - live out, which misses variables that are
62 // actually dead but not in a live-in set. Example:
63 // @f(%x: int, %y: int, %c: bool) {
64 // let %w = if (%c) {
65 // let %z = %y + %y;
66 // %z
67 // } else {
68 // %y
69 // };
70 // %w
71 // }
72 // After inserting kills:
73 // @f(%x: int, %y: int, %c: bool) {
74 // /* %x is always dead, so never in any live in or live out set */
75 // let %w = if (%c) {
76 // let %z = %y + %y;
77 // let %_0 = memory.kill(%y);
78 // %z
79 // } else {
80 // %y
81 // /* %y is dead at this point */
82 // };
83 // let %_1 = memory.kill(%c);
84 // /* no kill for %y since it's not in the live-in of %w AND %w isn't a let binding */
85 // %w
86 // }
87 //
88 // (3) When the result expr of an If branch is a variable, and this expr is the last use of the
89 // var, we cannot "kill" the var since it is being returned. The VM compiler also emits a Move
90 // instruction to merge the branch results, which creates another ObjectRef to the Object held
91 // by the var. The var is also not in the subsequent live-in (since it is indeed dead by this
92 // point), so it won't be killed. An example can be seen in the previous code block for (2), where
93 // %y is not killed if the else-branch is taken (and indeed it can be killed, as %w is mapped to
94 // a new register and holds a fresh reference to the object referenced by %y).
95 //
96 // However, these limitations are unlikely to cause large leaks in practice.
97
98 Expr VisitExpr_(const LetNode* let_node) override {
99 Expr expr = GetRef<Expr>(let_node);
100 LetList ll;
101
102 while (const LetNode* inner_let_node = expr.as<LetNode>()) {
103 ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value));
104
105 ICHECK(!inner_let_node->value.as<VarNode>()) << "aliasing should have been eliminated.";
106 ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG";
107
108 const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr);
109
110 const VarSet& li = lva_->live_in.at(n);
111 const VarSet& lo = lva_->live_out.at(n);
112
113 // Killed vars = live in - live out.
114 VarSet kills;
115 for (const Var& v : li) {
116 if (!lo.count(v)) {
117 kills.insert(v);
118 }
119 }
120
121 for (const Var& v : kills) {
122 ll.Push(Call(Op::Get("memory.kill"), {v}));
123 }
124
125 expr = inner_let_node->body;
126 }
127
128 return ll.Get(VisitExpr(expr));
129 }
130
131 private:
132 const ControlFlowGraph* cfg_;
133 const LivenessAnalysis* lva_;
134};
135
136/*!
137 * \brief Helper class to eliminate variable aliasing. This pass anticipates the VM compiler's
138 * register aliasing behavior so as to avoid killing vars that point to the same register. An
139 * alternative approach would be to track aliasing within the VM compiler itself, so that kill
140 * instructions are only emitted when all aliases are killed.
141 */
142class AliasEliminator : public MixedModeMutator {
143 public:
144 using MixedModeMutator::VisitExpr_;
145
146 Expr VisitExpr_(const LetNode* let_node) override {
147 Expr expr = GetRef<Expr>(let_node);
148 LetList ll;
149 std::vector<Var> aliased_vars;
150
151 while (const LetNode* inner_let_node = expr.as<LetNode>()) {
152 const Var& var = inner_let_node->var;
153 const Expr& val = inner_let_node->value;
154 bool aliased = false;
155 ICHECK(!alias_.count(var));
156
157 if (const VarNode* alias_of_n = AsIgnoringOnDevice<VarNode>(val)) {
158 alias_[var] = Downcast<Var>(VisitExpr_(alias_of_n));
159 aliased = true;
160 } else if (AsIgnoringOnDevice<CallNode>(val)) {
161 // Copying to the same device is aliasing.
162 // WARNING: this must be kept in sync with the VM compiler logic in
163 // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*).
164 Expr unwrapped = IgnoreOnDevice(val);
165 DeviceCopyProps copy_props = GetDeviceCopyProps(unwrapped);
166 if (copy_props.body.defined()) {
167 if (copy_props.src_virtual_device->device_type() ==
168 copy_props.dst_virtual_device->device_type() &&
169 copy_props.src_virtual_device->virtual_device_id ==
170 copy_props.dst_virtual_device->virtual_device_id) {
171 Expr to_copy = Downcast<Call>(unwrapped)->args[0];
172 if (const VarNode* alias_of_n = to_copy.as<VarNode>()) {
173 alias_[var] = Downcast<Var>(VisitExpr_(alias_of_n));
174 aliased = true;
175 }
176 }
177 }
178 }
179
180 if (!aliased) {
181 ll.Push(var, VisitExpr(val));
182 } else {
183 aliased_vars.push_back(var);
184 }
185
186 expr = inner_let_node->body;
187 }
188
189 Expr body = ll.Get(VisitExpr(expr));
190
191 // remove the aliased vars so that alias_ only tracks things in scope
192 for (const Var& v : aliased_vars) {
193 alias_.erase(v);
194 }
195
196 return body;
197 }
198
199 Expr VisitExpr_(const VarNode* var_node) override {
200 Var var = GetRef<Var>(var_node);
201 if (alias_.count(var)) {
202 return alias_[var];
203 }
204 return std::move(var);
205 }
206
207 Expr VisitExpr_(const FunctionNode* func_node) override {
208 Expr new_body = VisitExpr(func_node->body);
209 return WithFields(GetRef<Function>(func_node), /*opt_params=*/NullOpt, /*opt_body=*/new_body);
210 }
211
212 // The only register-level aliasing that occurs in Match expressions is when
213 // the deconstructed expression is a Var, and the matched pattern is also a Var.
214 Expr VisitExpr_(const MatchNode* match_node) override {
215 if (const VarNode* data_var_node = AsIgnoringOnDevice<VarNode>(match_node->data)) {
216 Var data_var = Downcast<Var>(VisitExpr_(data_var_node));
217 std::vector<Clause> new_clauses;
218 for (const Clause& clause : match_node->clauses) {
219 const PatternVarNode* pv_node = nullptr;
220 if ((pv_node = clause->lhs.as<PatternVarNode>())) {
221 alias_[pv_node->var] = data_var;
222 }
223 new_clauses.push_back(Clause(clause->lhs, VisitExpr(clause->rhs)));
224 if (pv_node) {
225 alias_.erase(pv_node->var);
226 }
227 }
228 return Match(data_var, new_clauses, match_node->complete, match_node->span);
229 } else {
230 return ExprMutator::VisitExpr_(match_node);
231 }
232 }
233
234 private:
235 /*!
236 * \brief Mapping of var -> var it's an alias of. Note that transitive aliases
237 * (e.g. x = 0; y = x; z = y) are mapped to the non-aliased variable (in this example "x").
238 */
239 std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> alias_;
240};
241
242Pass ManifestLifetimes() {
243 auto pass_func = [](Function f, IRModule m, PassContext pc) -> Function {
244 f = Downcast<Function>(AliasEliminator().Mutate(f));
245 Arena arena;
246 ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, f);
247 UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg);
248 LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def);
249 KillInserter ki(&cfg, &lva);
250 Function nf = Downcast<Function>(ki.Mutate(f));
251 return nf;
252 };
253 return CreateFunctionPass(pass_func, 0, "ManifestLifetimes", {});
254}
255
256TVM_REGISTER_GLOBAL("relay._transform.ManifestLifetimes").set_body_typed(ManifestLifetimes);
257
258} // namespace transform
259} // namespace relay
260} // namespace tvm
261