1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/liveness_analysis.cc |
22 | * \brief Analysis that collects the live variables before and after each node. |
23 | * NOTE: the input IR should be in ANF. |
24 | */ |
25 | |
26 | #include "./liveness_analysis.h" |
27 | |
28 | #include <list> |
29 | #include <unordered_set> |
30 | #include <utility> |
31 | #include <vector> |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | namespace transform { |
36 | |
37 | using support::Arena; |
38 | using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>; |
39 | |
40 | ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) { |
41 | return Creator().Create(arena, body); |
42 | } |
43 | |
44 | ControlFlowGraph ControlFlowGraph::Creator::Create(Arena* arena, const Expr& body) { |
45 | arena_ = arena; |
46 | cfg_.entry = BasicBlock::Make(arena); |
47 | VisitExpr(body, cfg_.entry); |
48 | return std::move(cfg_); |
49 | } |
50 | |
51 | void ControlFlowGraph::Creator::Succ(BasicBlockPtr from, BasicBlockPtr to) { |
52 | from->succ.push_back(to); |
53 | to->pred.push_back(from); |
54 | } |
55 | |
56 | void ControlFlowGraph::Creator::VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) { |
57 | ICHECK(!in_func_) << "nested functions not supported by CFG analysis" ; |
58 | in_func_ = true; |
59 | |
60 | // Unwrap the nested function and proceed normally. |
61 | if (f->HasNonzeroAttr(attr::kClosure)) { |
62 | ICHECK(f->body.as<FunctionNode>()); |
63 | return VisitExpr(Downcast<Function>(f->body)->body, parent); |
64 | } |
65 | |
66 | return VisitExpr(f->body, parent); |
67 | } |
68 | |
69 | void ControlFlowGraph::Creator::VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) { |
70 | Expr expr = GetRef<Expr>(let_node); |
71 | |
72 | while (const LetNode* inner_let_node = expr.as<LetNode>()) { |
73 | NodePtr curr_node = Node::Make(arena_, parent, expr); |
74 | |
75 | ICHECK(!cfg_.let_map.count(expr)); |
76 | cfg_.let_map[expr] = curr_node; |
77 | cfg_.reverse_post_order.push_back(curr_node); |
78 | |
79 | // The basic block ends upon reaching control flow, with successor blocks corresponding to the |
80 | // control flow branch exprs (true/false in If, and one for each clause in Match). |
81 | if (const IfNode* ite = AsIgnoringOnDevice<IfNode>(inner_let_node->value)) { |
82 | // Create the basic blocks for each branch and mark them as successors to the current block. |
83 | BasicBlockPtr t_block = BasicBlock::Make(arena_); |
84 | BasicBlockPtr f_block = BasicBlock::Make(arena_); |
85 | Succ(parent, t_block); |
86 | Succ(parent, f_block); |
87 | |
88 | VisitExpr(ite->true_branch, t_block); |
89 | VisitExpr(ite->false_branch, f_block); |
90 | |
91 | // All subsequent bindings (and/or the body expr) will be in a new basic block. |
92 | BasicBlockPtr next = BasicBlock::Make(arena_); |
93 | Succ(t_block, next); |
94 | Succ(f_block, next); |
95 | parent = next; |
96 | } else if (const MatchNode* match = AsIgnoringOnDevice<MatchNode>(inner_let_node->value)) { |
97 | // Same as above but one for each pattern. |
98 | std::vector<BasicBlockPtr> clause_blocks; |
99 | BasicBlockPtr next = BasicBlock::Make(arena_); |
100 | for (const Clause& clause : match->clauses) { |
101 | BasicBlockPtr clause_block = BasicBlock::Make(arena_); |
102 | Succ(parent, clause_block); |
103 | Succ(clause_block, next); |
104 | VisitExpr(clause->rhs, clause_block); |
105 | } |
106 | parent = next; |
107 | } |
108 | |
109 | expr = inner_let_node->body; |
110 | } |
111 | |
112 | VisitExpr(expr, parent); |
113 | } |
114 | |
115 | void ControlFlowGraph::Creator::VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) { |
116 | // TODO(@altanh): is there a way of making this work? |
117 | LOG(FATAL) << "If expressions should be bound to variables." ; |
118 | } |
119 | |
120 | void ControlFlowGraph::Creator::VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) { |
121 | // TODO(@altanh): same as If |
122 | LOG(FATAL) << "Match expressions should be bound to variables." ; |
123 | } |
124 | |
125 | VarSet VarUseCollector::VisitExpr_(const VarNode* var_node) { return {GetRef<Var>(var_node)}; } |
126 | |
127 | VarSet VarUseCollector::VisitExpr_(const CallNode* call_node) { |
128 | VarSet use = VisitExpr(call_node->op); |
129 | for (const Expr& arg : call_node->args) { |
130 | VarSet arg_use = VisitExpr(arg); |
131 | use.insert(arg_use.begin(), arg_use.end()); |
132 | } |
133 | return use; |
134 | } |
135 | |
136 | VarSet VarUseCollector::VisitExpr_(const TupleNode* tuple_node) { |
137 | VarSet use; |
138 | for (const Expr& field : tuple_node->fields) { |
139 | VarSet field_use = VisitExpr(field); |
140 | use.insert(field_use.begin(), field_use.end()); |
141 | } |
142 | return use; |
143 | } |
144 | |
145 | VarSet VarUseCollector::VisitExpr_(const TupleGetItemNode* get_node) { |
146 | return VisitExpr(get_node->tuple); |
147 | } |
148 | |
149 | VarSet VarUseCollector::VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); } |
150 | |
151 | VarSet VarUseCollector::VisitExpr_(const MatchNode* match_node) { |
152 | return VisitExpr(match_node->data); |
153 | } |
154 | |
155 | UseDefAnalysis UseDefAnalysis::Analyze(const CFG& cfg) { |
156 | UseDefAnalysis a; |
157 | |
158 | // One pass is sufficient. |
159 | for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) { |
160 | const CFG::NodePtr& node = *it; |
161 | if (const LetNode* let_node = AsIgnoringOnDevice<LetNode>(node->expr)) { |
162 | a.use[node] = a.use_collector.VisitExpr(let_node->value); |
163 | a.def[node] = let_node->var; |
164 | } else { |
165 | a.use[node] = a.use_collector.VisitExpr(node->expr); |
166 | a.def[node] = Var(); |
167 | } |
168 | } |
169 | |
170 | return a; |
171 | } |
172 | |
173 | bool SetEqual(const VarSet& a, const VarSet& b) { |
174 | if (a.size() != b.size()) { |
175 | return false; |
176 | } |
177 | for (auto& xa : a) { |
178 | if (!b.count(xa)) { |
179 | return false; |
180 | } |
181 | } |
182 | return true; |
183 | } |
184 | |
185 | LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg, |
186 | const UseDefAnalysis& use_def) { |
187 | LivenessAnalysis a; |
188 | std::list<CFG::NodePtr> worklist; |
189 | |
190 | // Initialize worklist to post-order traversal for quick convergence. |
191 | worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend()); |
192 | |
193 | // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm. |
194 | auto visitor = [&](const CFG::NodePtr n) { |
195 | VarSet old_in_n = a.live_in[n]; |
196 | VarSet old_out_n = a.live_out[n]; |
197 | |
198 | a.live_in[n] = use_def.use.at(n); |
199 | for (const Var& v : a.live_out[n]) { |
200 | if (!v.same_as(use_def.def.at(n))) { |
201 | a.live_in[n].insert(v); |
202 | } |
203 | } |
204 | |
205 | a.live_out[n] = VarSet(); |
206 | for (const CFG::NodePtr& s : n->GetSucc()) { |
207 | a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end()); |
208 | } |
209 | |
210 | if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) { |
211 | // No need to update the worklist. |
212 | } else { |
213 | // Add predecessor nodes back to worklist (no need to add successors, since each node's |
214 | // in/out sets are not dependent on its predecessors). |
215 | for (const CFG::NodePtr& p : n->GetPred()) { |
216 | worklist.push_back(p); |
217 | } |
218 | } |
219 | }; |
220 | |
221 | while (!worklist.empty()) { |
222 | const CFG::NodePtr n = worklist.front(); |
223 | worklist.pop_front(); |
224 | visitor(n); |
225 | } |
226 | |
227 | return a; |
228 | } |
229 | |
230 | } // namespace transform |
231 | } // namespace relay |
232 | } // namespace tvm |
233 | |