1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/analysis/call_graph.h
 * \brief Define data structures for the call graph of an IRModule. It borrows
 * the idea of how LLVM constructs its CallGraph.
24 *
25 * https://llvm.org/doxygen/CallGraph_8h_source.html
26 */
27
28#ifndef TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
29#define TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
30
31#include <tvm/ir/module.h>
32#include <tvm/relay/expr.h>
33#include <tvm/relay/function.h>
34#include <tvm/runtime/object.h>
35
36#include <memory>
37#include <string>
38#include <unordered_map>
39#include <unordered_set>
40#include <utility>
41#include <vector>
42
43namespace tvm {
44namespace relay {
45
// Forward declarations: CallGraphNode holds std::unique_ptr<CallGraphEntry>
// values and befriends CallGraph before either of those classes is defined.
class CallGraph;
class CallGraphEntry;
48
49class CallGraphNode : public Object {
50 using CallGraphMap =
51 std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectPtrHash, ObjectPtrEqual>;
52 // Create iterator alias for a CallGraphNode object.
53 using iterator = CallGraphMap::iterator;
54 using const_iterator = CallGraphMap::const_iterator;
55
56 public:
57 /*! \brief The IR module for creating a CallGraphNode. */
58 IRModule module;
59
60 /*! \brief Default constructor. */
61 CallGraphNode() {}
62
63 void VisitAttrs(AttrVisitor* v) { v->Visit("module", &module); }
64
65 /*!
66 * \brief Print the call graph.
67 *
68 * \param os The stream for printing.
69 */
70 void Print(std::ostream& os) const;
71
72 /*! \return The begin iterator. */
73 iterator begin() { return call_graph_.begin(); }
74 /*! \return The end iterator. */
75 iterator end() { return call_graph_.end(); }
76 /*! \return The begin iterator. */
77 const_iterator begin() const { return call_graph_.begin(); }
78 /*! \return The end iterator. */
79 const_iterator end() const { return call_graph_.end(); }
80
81 /*!
82 * \brief Get an element from the CallGraphNode using a GlobalVar.
83 *
84 * \param gv The GlobalVar used for indexing.
85 *
86 * \return The fetched element.
87 */
88 const CallGraphEntry* operator[](const GlobalVar& gv) const;
89 /*!
90 * \brief Get an element from the CallGraphNode using a GlobalVar.
91 *
92 * \param gv The GlobalVar used for indexing.
93 *
94 * \return The fetched element.
95 */
96 CallGraphEntry* operator[](const GlobalVar& gv);
97 /*!
98 * \brief Get an element from the CallGraphNode using the global function name.
99 *
100 * \param gvar_name The global function name used for indexing.
101 *
102 * \return The fetched element.
103 */
104 const CallGraphEntry* operator[](const std::string& gvar_name) const {
105 return (*this)[module->GetGlobalVar(gvar_name)];
106 }
107 /*!
108 * \brief Get an element from the CallGraphNode using the global function name.
109 *
110 * \param gvar_name The global function name used for indexing.
111 *
112 * \return The fetched element.
113 */
114 CallGraphEntry* operator[](const std::string& gvar_name) {
115 return (*this)[module->GetGlobalVar(gvar_name)];
116 }
117
118 /*!
119 * \brief Get the global function corresponding to the variable.
120 *
121 * \param var The global variable.
122 *
123 * \return The found global function.
124 */
125 BaseFunc GetGlobalFunction(const GlobalVar& var) const;
126
127 /*!
128 * \brief Get the entries/root nodes of CallGraphNode.
129 *
130 * Entry functions are never referenced by other functions.
131 * Note these functions can be recursive as well.
132 *
133 * \return The list of CallGraphEntry that represent entry nodes.
134 */
135 std::vector<CallGraphEntry*> GetEntryGlobals() const;
136
137 /*!
138 * \brief Remove a GlobalVar in a given CallGraphEntry from the current
139 * IR module.
140 *
141 * \param cg_node The CallGraphEntry that contains a global function to be
142 * removed.
143 * \param update_call_graph Indicate if we will update the CallGraph as well
144 * since updating is costly. We are only able to remove a leaf function
145 * when update_call_graph is disabled because the edges pointing to
146 * functions being removed are not updated.
147 *
148 * \return The GlobalVar removed from the current module.
149 */
150 GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph = false);
151
152 /*!
153 * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for
154 * the GlobalVar if it doesn't exist.
155 *
156 * \param gv The GlobalVar for query.
157 *
158 * \return The queried entry.
159 */
160 CallGraphEntry* LookupGlobalVar(const GlobalVar& gv);
161
162 /*!
163 * \brief Get the entries from the CallGraphNode in the topological order.
164 *
165 * This is useful for various module-level optimizations/analysis. For example,
166 * inlining requires the correct order of the functions being processed, i.e.
167 * callee should be always handled before callers.
168 *
169 * \return The list of collected entries that are sorted in the topological order.
170 */
171 std::vector<CallGraphEntry*> TopologicalOrder() const;
172
173 static constexpr const char* _type_key = "relay.CallGraph";
174 TVM_DECLARE_FINAL_OBJECT_INFO(CallGraphNode, Object);
175
176 private:
177 /*!
178 * \brief Create a CallGraphEntry for a global function and add it to the
179 * CallGraphNode.
180 *
181 * \param gv The global var.
182 * \param func The global function corresponding to `gv`.
183 */
184 void AddToCallGraph(const GlobalVar& gv, const Function& func);
185
186 /*! \brief A record contains GlobalVar to CallGraphEntry mapping. */
187 CallGraphMap call_graph_;
188
189 friend CallGraph;
190};
191
192/*!
193 * \brief The class that represents the call graph of a Relay IR module. It also
194 * provides a variety of utility functions for users to query, view, and update
195 * a call graph.
196 */
197class CallGraph : public ObjectRef {
198 using CallGraphMap =
199 std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectPtrHash, ObjectPtrEqual>;
200 // Create iterator alias for a CallGraph object.
201 using iterator = CallGraphMap::iterator;
202 using const_iterator = CallGraphMap::const_iterator;
203
204 public:
205 /*!
206 * \brief Construct a CallGraph from a IR module.
207 *
208 * \param module The IR module
209 */
210 explicit CallGraph(IRModule module);
211
212 /*!
213 * \brief Construct from an object pointer.
214 * \param n The object pointer.
215 */
216 explicit CallGraph(ObjectPtr<Object> n) : ObjectRef(n) {}
217
218 /*! \return The begin iterator. */
219 iterator begin() {
220 auto* n = operator->();
221 ICHECK(n);
222 return n->begin();
223 }
224 /*! \return The end iterator. */
225 iterator end() {
226 auto* n = operator->();
227 ICHECK(n);
228 return n->end();
229 }
230 /*! \return The begin iterator. */
231 const_iterator begin() const {
232 const auto* n = operator->();
233 ICHECK(n);
234 return n->begin();
235 }
236 /*! \return The end iterator. */
237 const_iterator end() const {
238 const auto* n = operator->();
239 ICHECK(n);
240 return n->end();
241 }
242
243 /*!
244 * \brief Get an element from the CallGraph using a GlobalVar.
245 *
246 * \param gv The GlobalVar used for indexing.
247 *
248 * \return The fetched element.
249 */
250 const CallGraphEntry* operator[](const GlobalVar& gv) const {
251 const auto* n = operator->();
252 ICHECK(n);
253 return (*n)[gv];
254 }
255 /*!
256 * \brief Get an element from the CallGraph using a GlobalVar.
257 *
258 * \param gv The GlobalVar used for indexing.
259 *
260 * \return The fetched element.
261 */
262 CallGraphEntry* operator[](const GlobalVar& gv) {
263 auto* n = operator->();
264 ICHECK(n);
265 return (*n)[gv];
266 }
267 /*!
268 * \brief Get an element from the CallGraph using the global function name.
269 *
270 * \param gvar_name The global function name used for indexing.
271 *
272 * \return The fetched element.
273 */
274 const CallGraphEntry* operator[](const std::string& gvar_name) const {
275 const auto* n = operator->();
276 ICHECK(n);
277 return (*n)[gvar_name];
278 }
279 /*!
280 * \brief Get an element from the CallGraph using the global function name.
281 *
282 * \param gvar_name The global function name used for indexing.
283 *
284 * \return The fetched element.
285 */
286 CallGraphEntry* operator[](const std::string& gvar_name) {
287 auto* n = operator->();
288 ICHECK(n);
289 return (*n)[gvar_name];
290 }
291
292 /*! \return mutable pointers to the node. */
293 CallGraphNode* operator->() const {
294 auto* ptr = get_mutable();
295 ICHECK(ptr != nullptr);
296 return static_cast<CallGraphNode*>(ptr);
297 }
298
299 private:
300 /*! \brief Overload the << operator to print a call graph. */
301 friend std::ostream& operator<<(std::ostream& os, const CallGraph&);
302};
303
304/*!
305 * \brief A node in the call graph. It maintains the edges from a caller to
306 * all callees.
307 */
308class CallGraphEntry {
309 public:
310 using CallGraphEntryPair = std::pair<GlobalVar, CallGraphEntry*>;
311 using CallGraphEntryVector = std::vector<CallGraphEntryPair>;
312 using CallGraphEntrySet = std::unordered_set<const CallGraphEntry*>;
313 // Create iterator alias for a CallGraphEntry object.
314 using iterator = std::vector<CallGraphEntryPair>::iterator;
315 using const_iterator = std::vector<CallGraphEntryPair>::const_iterator;
316
317 /*!
318 * \brief Construct from a GlobalVar.
319 *
320 * \param gv The GlobalVar to create a CallGraphEntry.
321 */
322 explicit CallGraphEntry(const GlobalVar& gv) : global_(gv) {}
323 /*!
324 * \brief Delete copy constructor.
325 */
326 CallGraphEntry(const CallGraphEntry&) = delete;
327 /*! \brief Delete assignment. */
328 CallGraphEntry& operator=(const CallGraphEntry&) = delete;
329
330 /*! \return The begin iterator */
331 iterator begin() { return called_globals_.begin(); }
332 /*! \return The end iterator */
333 iterator end() { return called_globals_.end(); }
334 /*! \return The const begin iterator */
335 const_iterator begin() const { return called_globals_.begin(); }
336 /*! \return The const end iterator */
337 const_iterator end() const { return called_globals_.end(); }
338
339 /*!
340 * \brief Return if the list of called nodes is empty.
341 *
342 * \return true if the list is empty. Otherwise, false.
343 */
344 bool empty() const { return called_globals_.empty(); }
345
346 /*!
347 * \brief Return the size of the list that represents the nodes are called by
348 * the current node.
349 *
350 * \return The number of called nodes.
351 */
352 uint32_t size() const { return static_cast<uint32_t>(called_globals_.size()); }
353
354 /*!
355 * \brief Fetch the i-th CallGraphEntry from the list of nodes that are called
356 * by the current function.
357 *
358 * \param i The index.
359 *
360 * \return The fetched CallGraphEntry.
361 */
362 CallGraphEntry* operator[](size_t i) const {
363 ICHECK_LT(i, called_globals_.size()) << "Invalid Index";
364 return called_globals_[i].second;
365 }
366
367 /*!
368 * \brief Print the call graph that is stemmed from the current CallGraphEntry.
369 *
370 * \param os The stream for printing.
371 */
372 void Print(std::ostream& os) const;
373
374 /*!
375 * \brief Return the number of times the global function is referenced.
376 *
377 * \return The count.
378 */
379 uint32_t GetRefCount() const { return ref_cnt_; }
380
381 /*!
382 * \brief Return the GlobalVar stored in the current CallGraphEntry.
383 *
384 * \return The GlobalVar.
385 */
386 GlobalVar GetGlobalVar() const { return global_; }
387
388 /*!
389 * \brief Return the name hint of the GlobalVar stored in the CallGraphEntry.
390 *
391 * \return The name hint of the global function.
392 */
393 std::string GetNameHint() const { return global_->name_hint; }
394
395 /*!
396 * \brief Return if the global function corresponding to the current
397 * CallGraphEntry is a recursive function.
398 *
399 * \return true if it is recursive. Otherwise, false.
400 */
401 bool IsRecursive() const { return is_recursive_; }
402
403 /*!
404 * \brief Return if the global function corresponding to the current
405 * CallGraphEntry is both a recursive function and an entry function. This type
406 * of function only has one reference which is called by itself.
407 *
408 * \return true if it is both a recursive function and an entry. Otherwise, false.
409 */
410 bool IsRecursiveEntry() const { return GetRefCount() == 1 && IsRecursive(); }
411
412 /*!
413 * \brief Return the topological order of the CallGraphEntry.
414 *
415 * \param visited A set of CallGraphEntry objects that have been visited.
416 *
417 * \return The list of CallGraphEntry that is represented in topological order.
418 */
419 std::vector<CallGraphEntry*> TopologicalOrder(
420 CallGraphEntrySet* visited = new CallGraphEntrySet()) const;
421
422 /*!
423 * \brief Remove all edges from the current CallGraphEntry to any global
424 * function it calls.
425 */
426 void CleanCallGraphEntries();
427
428 /*!
429 * \brief Add a node to the list of nodes that are being called by the current
430 * global function.
431 *
432 * \param cg_node The CallGraphEntry that will be added to the call list.
433 */
434 void AddCalledGlobal(CallGraphEntry* cg_node);
435
436 /*!
437 * \brief Remove a call edge to the global function from the current
438 * function.
439 *
440 * \param callee The function that is being called.
441 */
442 void RemoveCallTo(const GlobalVar& callee);
443
444 /*!
445 * \brief Remove all the edges that represent that calls to the global function
446 * stored in a given CallGraphEntry.
447 *
448 * \param callee The function that is being called.
449 */
450 void RemoveAllCallTo(CallGraphEntry* callee);
451
452 private:
453 /*! \brief Decrement the reference counter by 1. */
454 void DecRef() {
455 ICHECK_GT(ref_cnt_, 0);
456 --ref_cnt_;
457 }
458 /*! \brief Increment the reference counter by 1. */
459 void IncRef() { ++ref_cnt_; }
460
461 /*!
462 * \brief Mark if the global function stored in the CallGraphEntry is
463 * recursive function.
464 */
465 bool is_recursive_{false};
466 /*! \brief Count the number of times the global function is referenced. */
467 uint32_t ref_cnt_{0};
468 /*! \brief The GlobalVar stored in the current CallGraphEntry. */
469 GlobalVar global_;
470 /*! \brief The list of entries called by the current CallGraphEntry. */
471 CallGraphEntryVector called_globals_;
472
473 friend class CallGraph;
474 /*! \brief Overload the << operator to print a call graph node. */
475 friend std::ostream& operator<<(std::ostream& os, const CallGraphEntry&);
476};
477
478} // namespace relay
479} // namespace tvm
480#endif // TVM_RELAY_ANALYSIS_CALL_GRAPH_H_
481