1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/analysis/call_graph.cc |
22 | * \brief Implementation of APIs to handle the call graph of a Relay module. |
23 | */ |
24 | |
25 | #include "call_graph.h" |
26 | |
27 | #include <tvm/relay/attrs/annotation.h> |
28 | #include <tvm/relay/expr_functor.h> |
29 | #include <tvm/runtime/object.h> |
30 | |
31 | #include <algorithm> |
32 | #include <memory> |
33 | #include <sstream> |
34 | #include <unordered_set> |
35 | #include <vector> |
36 | |
37 | #include "../op/call/call.h" |
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | |
42 | CallGraph::CallGraph(IRModule module) { |
43 | auto n = make_object<CallGraphNode>(); |
44 | n->module = std::move(module); |
45 | auto gvar_funcs = n->module->functions; |
46 | for (const auto& it : gvar_funcs) { |
47 | if (const auto* fn = it.second.as<FunctionNode>()) { |
48 | auto func = GetRef<Function>(fn); |
49 | // Add the global function to gradually build up the call graph. |
50 | n->AddToCallGraph(it.first, func); |
51 | } |
52 | } |
53 | data_ = std::move(n); |
54 | } |
55 | |
56 | void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { |
57 | ICHECK(func.defined() && gv.defined()); |
58 | // Add the current global function as an entry to the call grpah. |
59 | CallGraphEntry* cg_node = LookupGlobalVar(gv); |
60 | |
61 | // Only GlobalVar nodes need to be handled in a function. It indicates that |
62 | // the global function of a callee is called by the function that is being |
63 | // processed. An edge will be added from the current global function, cg_node, |
64 | // to the node that contains the found callee GlobalVarNode. |
65 | // |
66 | // This is the major overhead for constructing a call graph because the |
67 | // post-order visitor will visit each AST node of the current function to |
68 | // figure out the dependencies between functions. |
69 | PostOrderVisit(func, [&](const Expr& expr) { |
70 | // TODO(mbs): Cleanup shapes functions. |
71 | if (const auto* call_node = expr.as<CallNode>()) { |
72 | CallLoweredProps props = GetCallLoweredProps(call_node); |
73 | if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var" )) { |
74 | // We are implicitly calling the shape function *in addition to* the call target. |
75 | CallGraphEntry* callee_cg_node = |
76 | LookupGlobalVar(Downcast<GlobalVar>(props.attrs.metadata["prim_shape_fn_var" ])); |
77 | cg_node->AddCalledGlobal(callee_cg_node); |
78 | } |
79 | } else if (const auto* global_var_node = expr.as<GlobalVarNode>()) { |
80 | auto callee = GetRef<GlobalVar>(global_var_node); |
81 | CallGraphEntry* callee_cg_node = LookupGlobalVar(callee); |
82 | cg_node->AddCalledGlobal(callee_cg_node); |
83 | } |
84 | }); |
85 | } |
86 | |
87 | const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const { |
88 | const_iterator cit = call_graph_.find(gv); |
89 | ICHECK(cit != call_graph_.end()) |
90 | << "GlobalVar " << gv->name_hint << " not found in the call graph!" ; |
91 | return cit->second.get(); |
92 | } |
93 | |
94 | CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) { |
95 | const_iterator cit = call_graph_.find(gv); |
96 | ICHECK(cit != call_graph_.end()) |
97 | << "GlobalVar " << gv->name_hint << " not found in the call graph!" ; |
98 | return cit->second.get(); |
99 | } |
100 | |
101 | BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const { |
102 | ICHECK(module->ContainGlobalVar(var->name_hint)) |
103 | << "GlobalVar " << var->name_hint << " not found in the current ir module" ; |
104 | return module->Lookup(var->name_hint); |
105 | } |
106 | |
107 | CallGraphEntry* CallGraphNode::LookupGlobalVar(const GlobalVar& gv) { |
108 | ICHECK(gv.defined()); |
109 | |
110 | // This inserts an element to the call graph if it is not there yet. |
111 | auto& call_graph_node = call_graph_[gv]; |
112 | if (call_graph_node) return call_graph_node.get(); |
113 | |
114 | // Create the node for the inserted entry. |
115 | call_graph_node = std::make_unique<CallGraphEntry>(gv); |
116 | return call_graph_node.get(); |
117 | } |
118 | |
119 | void CallGraphNode::Print(std::ostream& os) const { |
120 | // Print the call graph in the topological order. |
121 | std::vector<CallGraphEntry*> nodes = TopologicalOrder(); |
122 | for (const auto* cgn : nodes) { |
123 | cgn->Print(os); |
124 | } |
125 | } |
126 | |
127 | GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntry* cg_node, |
128 | bool update_call_graph) { |
129 | ICHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1)) |
130 | << "Cannot remove global var " << cg_node->GetNameHint() |
131 | << " from call graph, because it still calls " << cg_node->size() |
132 | << " other global functions" ; |
133 | |
134 | if (update_call_graph) { |
135 | // Update the call graph by removing all edges that point to the node |
136 | // `cg_node`. |
137 | for (auto& it : *this) { |
138 | it.second->RemoveAllCallTo(cg_node); |
139 | } |
140 | } |
141 | GlobalVar gv = cg_node->GetGlobalVar(); |
142 | call_graph_.erase(gv); |
143 | // Update the IR module. |
144 | module->Remove(gv); |
145 | return gv; |
146 | } |
147 | |
148 | std::vector<CallGraphEntry*> CallGraphNode::GetEntryGlobals() const { |
149 | std::vector<CallGraphEntry*> ret; |
150 | // An entry function in Relay is a function that never called by other |
151 | // functions or only called by itself. |
152 | for (const auto& it : *this) { |
153 | if (it.second->GetRefCount() == 0 || it.second->IsRecursiveEntry()) { |
154 | ret.push_back(it.second.get()); |
155 | } |
156 | } |
157 | return ret; |
158 | } |
159 | |
160 | std::vector<CallGraphEntry*> CallGraphNode::TopologicalOrder() const { |
161 | std::vector<CallGraphEntry*> ret; |
162 | // Collect all entry nodes. |
163 | std::vector<CallGraphEntry*> entries = GetEntryGlobals(); |
164 | CallGraphEntry::CallGraphEntrySet visited; |
165 | |
166 | for (const auto& it : entries) { |
167 | // Keep tracking the nodes that have been visited. |
168 | auto topo = it->TopologicalOrder(&visited); |
169 | // Prepend the collected items. The intermediate nodes that are shared by |
170 | // multiple entries are guaranteed to be collected when visiting the |
171 | // previous entries. Therefore, topological order remains. |
172 | ret.insert(ret.begin(), topo.begin(), topo.end()); |
173 | } |
174 | |
175 | // Find out the missing global functions if there are any to help debugging. |
176 | if (ret.size() != module->functions.size()) { |
177 | for (auto it : module->functions) { |
178 | if (visited.find((*this)[it.first]) == visited.end()) { |
179 | LOG(WARNING) << "Missing global:" << it.first->name_hint |
180 | << " with # refs = " << (*this)[it.first]->GetRefCount(); |
181 | } |
182 | } |
183 | LOG(FATAL) << "Expected " << module->functions.size() << " globals, but received " |
184 | << ret.size(); |
185 | } |
186 | |
187 | return ret; |
188 | } |
189 | |
190 | // BSF traversal is used to collect the nodes in a CallGraphEntry. The nodes |
191 | // that are visited by previous CallGraphEntry entries can be memoized. This |
192 | // helps us to make sure no entry will be visited multiple times when collecting |
193 | // the nodes for an entire call graph. |
194 | std::vector<CallGraphEntry*> CallGraphEntry::TopologicalOrder(CallGraphEntrySet* visited) const { |
195 | std::vector<CallGraphEntry*> ret; |
196 | std::vector<CallGraphEntry*> current_nodes; |
197 | if (visited->find(this) == visited->end()) { |
198 | visited->emplace(this); |
199 | current_nodes.emplace_back(const_cast<CallGraphEntry*>(this)); |
200 | } |
201 | |
202 | std::vector<CallGraphEntry*> next_nodes; |
203 | while (!current_nodes.empty()) { |
204 | for (const auto& node : current_nodes) { |
205 | ret.push_back(node); |
206 | // Iterate through the called entries. |
207 | for (auto git = node->begin(); git != node->end(); ++git) { |
208 | if (visited->find(git->second) == visited->end()) { |
209 | next_nodes.push_back(git->second); |
210 | visited->emplace(git->second); |
211 | } |
212 | } |
213 | } |
214 | // Update the current level and clean the next level. |
215 | current_nodes = next_nodes; |
216 | next_nodes.clear(); |
217 | } |
218 | return ret; |
219 | } |
220 | |
221 | void CallGraphEntry::CleanCallGraphEntries() { |
222 | while (!called_globals_.empty()) { |
223 | // Decrement the reference counter |
224 | called_globals_.back().second->DecRef(); |
225 | called_globals_.pop_back(); |
226 | } |
227 | } |
228 | |
229 | inline void CallGraphEntry::AddCalledGlobal(CallGraphEntry* cg_node) { |
230 | called_globals_.emplace_back(global_, cg_node); |
231 | // Increment the reference to indicate that another call site is found for |
232 | // the callee in `cg_node`. |
233 | cg_node->IncRef(); |
234 | // Mark the global function as recursive if it calls itself. |
235 | if (global_ == cg_node->GetGlobalVar()) { |
236 | cg_node->is_recursive_ = true; |
237 | } |
238 | } |
239 | |
240 | // Remove an edge from the current global function to the callee. |
241 | void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) { |
242 | for (auto it = begin();; ++it) { |
243 | ICHECK(it != end()) << "Cannot find global function " << callee->name_hint << " to remove!" ; |
244 | if (it->second->GetGlobalVar() == callee) { |
245 | // Only remove one occurrence of the call site. |
246 | it->second->DecRef(); |
247 | *it = called_globals_.back(); |
248 | called_globals_.pop_back(); |
249 | return; |
250 | } |
251 | } |
252 | } |
253 | |
254 | // Remove all edges from the current global function to the callee. |
255 | void CallGraphEntry::RemoveAllCallTo(CallGraphEntry* callee) { |
256 | for (uint32_t i = 0, e = size(); i != e;) { |
257 | if (called_globals_[i].second == callee) { |
258 | callee->DecRef(); |
259 | called_globals_[i] = called_globals_.back(); |
260 | called_globals_.pop_back(); |
261 | --e; |
262 | } else { |
263 | ++i; |
264 | } |
265 | } |
266 | // Make sure all references to the callee are removed. |
267 | ICHECK_EQ(callee->GetRefCount(), 0U) |
268 | << "All references to " << callee->GetNameHint() << " should have been removed" ; |
269 | } |
270 | |
271 | void CallGraphEntry::Print(std::ostream& os) const { |
272 | if (!global_.defined()) { |
273 | os << "GlobalVar is not defined\n" ; |
274 | return; |
275 | } |
276 | |
277 | os << "Call graph node: " << global_->name_hint; |
278 | os << " at: " << this << ", #refs = " << GetRefCount() << "\n" ; |
279 | |
280 | for (const auto& it : *this) { |
281 | os << " call site: <" << it.first->name_hint << "> calls " ; |
282 | os << it.second->GetNameHint() << "\n" ; |
283 | } |
284 | os << "\n" ; |
285 | } |
286 | |
287 | std::ostream& operator<<(std::ostream& os, const CallGraph& cg) { |
288 | cg->Print(os); |
289 | return os; |
290 | } |
291 | |
292 | std::ostream& operator<<(std::ostream& os, const CallGraphEntry& cgn) { |
293 | cgn.Print(os); |
294 | return os; |
295 | } |
296 | |
// Register CallGraphNode with the TVM object system.
TVM_REGISTER_NODE_TYPE(CallGraphNode);

// Text representation of a CallGraph object: a "CallGraph: " header followed by the
// output of operator<< for CallGraph, which prints every entry in the order given by
// CallGraphNode::TopologicalOrder().
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<CallGraphNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const CallGraphNode*>(ref.get());
      ICHECK(node);
      p->stream << "CallGraph: \n" << GetRef<CallGraph>(node);
    });
305 | |
// Packed-function entry points for the call graph API defined above. Each one is
// registered under the quoted global name.

// Build a call graph from `module` (see CallGraph::CallGraph).
TVM_REGISTER_GLOBAL("relay.analysis.CallGraph").set_body_typed([](IRModule module) {
  return CallGraph(module);
});

// Return the text produced by operator<< for the whole call graph.
TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph").set_body_typed([](CallGraph call_graph) {
  std::stringstream ss;
  ss << call_graph;
  return ss.str();
});

// Return the IRModule held by the call graph. RemoveGlobalVarFromModule removes
// functions from this same module.
TVM_REGISTER_GLOBAL("relay.analysis.GetModule").set_body_typed([](CallGraph call_graph) {
  return call_graph->module;
});

// Return the text produced by CallGraphEntry::Print for the entry of `var`.
// NOTE(review): `call_graph[var]` presumably forwards to CallGraphNode::operator[],
// which ICHECK-fails when `var` is not in the graph — confirm in call_graph.h.
TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar")
    .set_body_typed([](CallGraph call_graph, GlobalVar var) {
      const auto* entry_node = call_graph[var];
      std::stringstream ss;
      ss << *entry_node;
      return ss.str();
    });

// Return how many call sites in the graph target `var`. The count grows by one for
// every AddCalledGlobal that points at `var`'s entry.
TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar")
    .set_body_typed([](CallGraph call_graph, GlobalVar var) {
      const auto* entry_node = call_graph[var];
      return static_cast<int>(entry_node->GetRefCount());
    });

// Return how many call sites the function bound to `var` contains (its outgoing edges).
TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount")
    .set_body_typed([](CallGraph call_graph, GlobalVar var) {
      const auto* entry_node = call_graph[var];
      return static_cast<int>(entry_node->size());
    });

// Return whether the function bound to `var` calls itself.
TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive")
    .set_body_typed([](CallGraph call_graph, GlobalVar var) {
      const auto* entry_node = call_graph[var];
      return entry_node->IsRecursive();
    });
345 | |
346 | } // namespace relay |
347 | } // namespace tvm |
348 | |