1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/analysis/call_graph.cc
22 * \brief Implementation of APIs to handle the call graph of a Relay module.
23 */
24
25#include "call_graph.h"
26
27#include <tvm/relay/attrs/annotation.h>
28#include <tvm/relay/expr_functor.h>
29#include <tvm/runtime/object.h>
30
31#include <algorithm>
32#include <memory>
33#include <sstream>
34#include <unordered_set>
35#include <vector>
36
37#include "../op/call/call.h"
38
39namespace tvm {
40namespace relay {
41
42CallGraph::CallGraph(IRModule module) {
43 auto n = make_object<CallGraphNode>();
44 n->module = std::move(module);
45 auto gvar_funcs = n->module->functions;
46 for (const auto& it : gvar_funcs) {
47 if (const auto* fn = it.second.as<FunctionNode>()) {
48 auto func = GetRef<Function>(fn);
49 // Add the global function to gradually build up the call graph.
50 n->AddToCallGraph(it.first, func);
51 }
52 }
53 data_ = std::move(n);
54}
55
56void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) {
57 ICHECK(func.defined() && gv.defined());
58 // Add the current global function as an entry to the call grpah.
59 CallGraphEntry* cg_node = LookupGlobalVar(gv);
60
61 // Only GlobalVar nodes need to be handled in a function. It indicates that
62 // the global function of a callee is called by the function that is being
63 // processed. An edge will be added from the current global function, cg_node,
64 // to the node that contains the found callee GlobalVarNode.
65 //
66 // This is the major overhead for constructing a call graph because the
67 // post-order visitor will visit each AST node of the current function to
68 // figure out the dependencies between functions.
69 PostOrderVisit(func, [&](const Expr& expr) {
70 // TODO(mbs): Cleanup shapes functions.
71 if (const auto* call_node = expr.as<CallNode>()) {
72 CallLoweredProps props = GetCallLoweredProps(call_node);
73 if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) {
74 // We are implicitly calling the shape function *in addition to* the call target.
75 CallGraphEntry* callee_cg_node =
76 LookupGlobalVar(Downcast<GlobalVar>(props.attrs.metadata["prim_shape_fn_var"]));
77 cg_node->AddCalledGlobal(callee_cg_node);
78 }
79 } else if (const auto* global_var_node = expr.as<GlobalVarNode>()) {
80 auto callee = GetRef<GlobalVar>(global_var_node);
81 CallGraphEntry* callee_cg_node = LookupGlobalVar(callee);
82 cg_node->AddCalledGlobal(callee_cg_node);
83 }
84 });
85}
86
87const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const {
88 const_iterator cit = call_graph_.find(gv);
89 ICHECK(cit != call_graph_.end())
90 << "GlobalVar " << gv->name_hint << " not found in the call graph!";
91 return cit->second.get();
92}
93
94CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) {
95 const_iterator cit = call_graph_.find(gv);
96 ICHECK(cit != call_graph_.end())
97 << "GlobalVar " << gv->name_hint << " not found in the call graph!";
98 return cit->second.get();
99}
100
101BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const {
102 ICHECK(module->ContainGlobalVar(var->name_hint))
103 << "GlobalVar " << var->name_hint << " not found in the current ir module";
104 return module->Lookup(var->name_hint);
105}
106
107CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) {
108 ICHECK(gv.defined());
109
110 // This inserts an element to the call graph if it is not there yet.
111 auto& call_graph_node = call_graph_[gv];
112 if (call_graph_node) return call_graph_node.get();
113
114 // Create the node for the inserted entry.
115 call_graph_node = std::make_unique<CallGraphEntry>(gv);
116 return call_graph_node.get();
117}
118
119void CallGraphNode::Print(std::ostream& os) const {
120 // Print the call graph in the topological order.
121 std::vector<CallGraphEntry*> nodes = TopologicalOrder();
122 for (const auto* cgn : nodes) {
123 cgn->Print(os);
124 }
125}
126
127GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntry* cg_node,
128 bool update_call_graph) {
129 ICHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1))
130 << "Cannot remove global var " << cg_node->GetNameHint()
131 << " from call graph, because it still calls " << cg_node->size()
132 << " other global functions";
133
134 if (update_call_graph) {
135 // Update the call graph by removing all edges that point to the node
136 // `cg_node`.
137 for (auto& it : *this) {
138 it.second->RemoveAllCallTo(cg_node);
139 }
140 }
141 GlobalVar gv = cg_node->GetGlobalVar();
142 call_graph_.erase(gv);
143 // Update the IR module.
144 module->Remove(gv);
145 return gv;
146}
147
148std::vector<CallGraphEntry*> CallGraphNode::GetEntryGlobals() const {
149 std::vector<CallGraphEntry*> ret;
150 // An entry function in Relay is a function that never called by other
151 // functions or only called by itself.
152 for (const auto& it : *this) {
153 if (it.second->GetRefCount() == 0 || it.second->IsRecursiveEntry()) {
154 ret.push_back(it.second.get());
155 }
156 }
157 return ret;
158}
159
160std::vector<CallGraphEntry*> CallGraphNode::TopologicalOrder() const {
161 std::vector<CallGraphEntry*> ret;
162 // Collect all entry nodes.
163 std::vector<CallGraphEntry*> entries = GetEntryGlobals();
164 CallGraphEntry::CallGraphEntrySet visited;
165
166 for (const auto& it : entries) {
167 // Keep tracking the nodes that have been visited.
168 auto topo = it->TopologicalOrder(&visited);
169 // Prepend the collected items. The intermediate nodes that are shared by
170 // multiple entries are guaranteed to be collected when visiting the
171 // previous entries. Therefore, topological order remains.
172 ret.insert(ret.begin(), topo.begin(), topo.end());
173 }
174
175 // Find out the missing global functions if there are any to help debugging.
176 if (ret.size() != module->functions.size()) {
177 for (auto it : module->functions) {
178 if (visited.find((*this)[it.first]) == visited.end()) {
179 LOG(WARNING) << "Missing global:" << it.first->name_hint
180 << " with # refs = " << (*this)[it.first]->GetRefCount();
181 }
182 }
183 LOG(FATAL) << "Expected " << module->functions.size() << " globals, but received "
184 << ret.size();
185 }
186
187 return ret;
188}
189
190// BSF traversal is used to collect the nodes in a CallGraphEntry. The nodes
191// that are visited by previous CallGraphEntry entries can be memoized. This
192// helps us to make sure no entry will be visited multiple times when collecting
193// the nodes for an entire call graph.
194std::vector<CallGraphEntry*> CallGraphEntry::TopologicalOrder(CallGraphEntrySet* visited) const {
195 std::vector<CallGraphEntry*> ret;
196 std::vector<CallGraphEntry*> current_nodes;
197 if (visited->find(this) == visited->end()) {
198 visited->emplace(this);
199 current_nodes.emplace_back(const_cast<CallGraphEntry*>(this));
200 }
201
202 std::vector<CallGraphEntry*> next_nodes;
203 while (!current_nodes.empty()) {
204 for (const auto& node : current_nodes) {
205 ret.push_back(node);
206 // Iterate through the called entries.
207 for (auto git = node->begin(); git != node->end(); ++git) {
208 if (visited->find(git->second) == visited->end()) {
209 next_nodes.push_back(git->second);
210 visited->emplace(git->second);
211 }
212 }
213 }
214 // Update the current level and clean the next level.
215 current_nodes = next_nodes;
216 next_nodes.clear();
217 }
218 return ret;
219}
220
221void CallGraphEntry::CleanCallGraphEntries() {
222 while (!called_globals_.empty()) {
223 // Decrement the reference counter
224 called_globals_.back().second->DecRef();
225 called_globals_.pop_back();
226 }
227}
228
229inline void CallGraphEntry::AddCalledGlobal(CallGraphEntry* cg_node) {
230 called_globals_.emplace_back(global_, cg_node);
231 // Increment the reference to indicate that another call site is found for
232 // the callee in `cg_node`.
233 cg_node->IncRef();
234 // Mark the global function as recursive if it calls itself.
235 if (global_ == cg_node->GetGlobalVar()) {
236 cg_node->is_recursive_ = true;
237 }
238}
239
240// Remove an edge from the current global function to the callee.
241void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) {
242 for (auto it = begin();; ++it) {
243 ICHECK(it != end()) << "Cannot find global function " << callee->name_hint << " to remove!";
244 if (it->second->GetGlobalVar() == callee) {
245 // Only remove one occurrence of the call site.
246 it->second->DecRef();
247 *it = called_globals_.back();
248 called_globals_.pop_back();
249 return;
250 }
251 }
252}
253
254// Remove all edges from the current global function to the callee.
255void CallGraphEntry::RemoveAllCallTo(CallGraphEntry* callee) {
256 for (uint32_t i = 0, e = size(); i != e;) {
257 if (called_globals_[i].second == callee) {
258 callee->DecRef();
259 called_globals_[i] = called_globals_.back();
260 called_globals_.pop_back();
261 --e;
262 } else {
263 ++i;
264 }
265 }
266 // Make sure all references to the callee are removed.
267 ICHECK_EQ(callee->GetRefCount(), 0U)
268 << "All references to " << callee->GetNameHint() << " should have been removed";
269}
270
271void CallGraphEntry::Print(std::ostream& os) const {
272 if (!global_.defined()) {
273 os << "GlobalVar is not defined\n";
274 return;
275 }
276
277 os << "Call graph node: " << global_->name_hint;
278 os << " at: " << this << ", #refs = " << GetRefCount() << "\n";
279
280 for (const auto& it : *this) {
281 os << " call site: <" << it.first->name_hint << "> calls ";
282 os << it.second->GetNameHint() << "\n";
283 }
284 os << "\n";
285}
286
287std::ostream& operator<<(std::ostream& os, const CallGraph& cg) {
288 cg->Print(os);
289 return os;
290}
291
292std::ostream& operator<<(std::ostream& os, const CallGraphEntry& cgn) {
293 cgn.Print(os);
294 return os;
295}
296
// Register the CallGraphNode object type.
TVM_REGISTER_NODE_TYPE(CallGraphNode);

// ReprPrinter dispatch for CallGraphNode: writes a "CallGraph: " header line
// followed by the call graph as streamed by operator<<(std::ostream&, const CallGraph&).
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<CallGraphNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const CallGraphNode*>(ref.get());
      ICHECK(node);
      p->stream << "CallGraph: \n" << GetRef<CallGraph>(node);
    });
305
306TVM_REGISTER_GLOBAL("relay.analysis.CallGraph").set_body_typed([](IRModule module) {
307 return CallGraph(module);
308});
309
310TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph").set_body_typed([](CallGraph call_graph) {
311 std::stringstream ss;
312 ss << call_graph;
313 return ss.str();
314});
315
316TVM_REGISTER_GLOBAL("relay.analysis.GetModule").set_body_typed([](CallGraph call_graph) {
317 return call_graph->module;
318});
319
320TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar")
321 .set_body_typed([](CallGraph call_graph, GlobalVar var) {
322 const auto* entry_node = call_graph[var];
323 std::stringstream ss;
324 ss << *entry_node;
325 return ss.str();
326 });
327
328TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar")
329 .set_body_typed([](CallGraph call_graph, GlobalVar var) {
330 const auto* entry_node = call_graph[var];
331 return static_cast<int>(entry_node->GetRefCount());
332 });
333
334TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount")
335 .set_body_typed([](CallGraph call_graph, GlobalVar var) {
336 const auto* entry_node = call_graph[var];
337 return static_cast<int>(entry_node->size());
338 });
339
340TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive")
341 .set_body_typed([](CallGraph call_graph, GlobalVar var) {
342 const auto* entry_node = call_graph[var];
343 return entry_node->IsRecursive();
344 });
345
346} // namespace relay
347} // namespace tvm
348