1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/backend/liveness_analysis.cc
22 * \brief Analysis that collects the live variables before and after each node.
23 * NOTE: the input IR should be in ANF.
24 */
25
26#include "./liveness_analysis.h"
27
28#include <list>
29#include <unordered_set>
30#include <utility>
31#include <vector>
32
33namespace tvm {
34namespace relay {
35namespace transform {
36
37using support::Arena;
38using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>;
39
40ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) {
41 return Creator().Create(arena, body);
42}
43
44ControlFlowGraph ControlFlowGraph::Creator::Create(Arena* arena, const Expr& body) {
45 arena_ = arena;
46 cfg_.entry = BasicBlock::Make(arena);
47 VisitExpr(body, cfg_.entry);
48 return std::move(cfg_);
49}
50
51void ControlFlowGraph::Creator::Succ(BasicBlockPtr from, BasicBlockPtr to) {
52 from->succ.push_back(to);
53 to->pred.push_back(from);
54}
55
56void ControlFlowGraph::Creator::VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) {
57 ICHECK(!in_func_) << "nested functions not supported by CFG analysis";
58 in_func_ = true;
59
60 // Unwrap the nested function and proceed normally.
61 if (f->HasNonzeroAttr(attr::kClosure)) {
62 ICHECK(f->body.as<FunctionNode>());
63 return VisitExpr(Downcast<Function>(f->body)->body, parent);
64 }
65
66 return VisitExpr(f->body, parent);
67}
68
69void ControlFlowGraph::Creator::VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) {
70 Expr expr = GetRef<Expr>(let_node);
71
72 while (const LetNode* inner_let_node = expr.as<LetNode>()) {
73 NodePtr curr_node = Node::Make(arena_, parent, expr);
74
75 ICHECK(!cfg_.let_map.count(expr));
76 cfg_.let_map[expr] = curr_node;
77 cfg_.reverse_post_order.push_back(curr_node);
78
79 // The basic block ends upon reaching control flow, with successor blocks corresponding to the
80 // control flow branch exprs (true/false in If, and one for each clause in Match).
81 if (const IfNode* ite = AsIgnoringOnDevice<IfNode>(inner_let_node->value)) {
82 // Create the basic blocks for each branch and mark them as successors to the current block.
83 BasicBlockPtr t_block = BasicBlock::Make(arena_);
84 BasicBlockPtr f_block = BasicBlock::Make(arena_);
85 Succ(parent, t_block);
86 Succ(parent, f_block);
87
88 VisitExpr(ite->true_branch, t_block);
89 VisitExpr(ite->false_branch, f_block);
90
91 // All subsequent bindings (and/or the body expr) will be in a new basic block.
92 BasicBlockPtr next = BasicBlock::Make(arena_);
93 Succ(t_block, next);
94 Succ(f_block, next);
95 parent = next;
96 } else if (const MatchNode* match = AsIgnoringOnDevice<MatchNode>(inner_let_node->value)) {
97 // Same as above but one for each pattern.
98 std::vector<BasicBlockPtr> clause_blocks;
99 BasicBlockPtr next = BasicBlock::Make(arena_);
100 for (const Clause& clause : match->clauses) {
101 BasicBlockPtr clause_block = BasicBlock::Make(arena_);
102 Succ(parent, clause_block);
103 Succ(clause_block, next);
104 VisitExpr(clause->rhs, clause_block);
105 }
106 parent = next;
107 }
108
109 expr = inner_let_node->body;
110 }
111
112 VisitExpr(expr, parent);
113}
114
115void ControlFlowGraph::Creator::VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) {
116 // TODO(@altanh): is there a way of making this work?
117 LOG(FATAL) << "If expressions should be bound to variables.";
118}
119
120void ControlFlowGraph::Creator::VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) {
121 // TODO(@altanh): same as If
122 LOG(FATAL) << "Match expressions should be bound to variables.";
123}
124
125VarSet VarUseCollector::VisitExpr_(const VarNode* var_node) { return {GetRef<Var>(var_node)}; }
126
127VarSet VarUseCollector::VisitExpr_(const CallNode* call_node) {
128 VarSet use = VisitExpr(call_node->op);
129 for (const Expr& arg : call_node->args) {
130 VarSet arg_use = VisitExpr(arg);
131 use.insert(arg_use.begin(), arg_use.end());
132 }
133 return use;
134}
135
136VarSet VarUseCollector::VisitExpr_(const TupleNode* tuple_node) {
137 VarSet use;
138 for (const Expr& field : tuple_node->fields) {
139 VarSet field_use = VisitExpr(field);
140 use.insert(field_use.begin(), field_use.end());
141 }
142 return use;
143}
144
145VarSet VarUseCollector::VisitExpr_(const TupleGetItemNode* get_node) {
146 return VisitExpr(get_node->tuple);
147}
148
149VarSet VarUseCollector::VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); }
150
151VarSet VarUseCollector::VisitExpr_(const MatchNode* match_node) {
152 return VisitExpr(match_node->data);
153}
154
155UseDefAnalysis UseDefAnalysis::Analyze(const CFG& cfg) {
156 UseDefAnalysis a;
157
158 // One pass is sufficient.
159 for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) {
160 const CFG::NodePtr& node = *it;
161 if (const LetNode* let_node = AsIgnoringOnDevice<LetNode>(node->expr)) {
162 a.use[node] = a.use_collector.VisitExpr(let_node->value);
163 a.def[node] = let_node->var;
164 } else {
165 a.use[node] = a.use_collector.VisitExpr(node->expr);
166 a.def[node] = Var();
167 }
168 }
169
170 return a;
171}
172
173bool SetEqual(const VarSet& a, const VarSet& b) {
174 if (a.size() != b.size()) {
175 return false;
176 }
177 for (auto& xa : a) {
178 if (!b.count(xa)) {
179 return false;
180 }
181 }
182 return true;
183}
184
185LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg,
186 const UseDefAnalysis& use_def) {
187 LivenessAnalysis a;
188 std::list<CFG::NodePtr> worklist;
189
190 // Initialize worklist to post-order traversal for quick convergence.
191 worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend());
192
193 // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm.
194 auto visitor = [&](const CFG::NodePtr n) {
195 VarSet old_in_n = a.live_in[n];
196 VarSet old_out_n = a.live_out[n];
197
198 a.live_in[n] = use_def.use.at(n);
199 for (const Var& v : a.live_out[n]) {
200 if (!v.same_as(use_def.def.at(n))) {
201 a.live_in[n].insert(v);
202 }
203 }
204
205 a.live_out[n] = VarSet();
206 for (const CFG::NodePtr& s : n->GetSucc()) {
207 a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end());
208 }
209
210 if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) {
211 // No need to update the worklist.
212 } else {
213 // Add predecessor nodes back to worklist (no need to add successors, since each node's
214 // in/out sets are not dependent on its predecessors).
215 for (const CFG::NodePtr& p : n->GetPred()) {
216 worklist.push_back(p);
217 }
218 }
219 };
220
221 while (!worklist.empty()) {
222 const CFG::NodePtr n = worklist.front();
223 worklist.pop_front();
224 visitor(n);
225 }
226
227 return a;
228}
229
230} // namespace transform
231} // namespace relay
232} // namespace tvm
233