1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/analysis/call_graph.h |
 * \brief Define data structures for the call graph of an IRModule. It borrows
 * the idea of how LLVM constructs its CallGraph.
24 | * |
25 | * https://llvm.org/doxygen/CallGraph_8h_source.html |
26 | */ |
27 | |
28 | #ifndef TVM_RELAY_ANALYSIS_CALL_GRAPH_H_ |
29 | #define TVM_RELAY_ANALYSIS_CALL_GRAPH_H_ |
30 | |
31 | #include <tvm/ir/module.h> |
32 | #include <tvm/relay/expr.h> |
33 | #include <tvm/relay/function.h> |
34 | #include <tvm/runtime/object.h> |
35 | |
36 | #include <memory> |
37 | #include <string> |
38 | #include <unordered_map> |
39 | #include <unordered_set> |
40 | #include <utility> |
41 | #include <vector> |
42 | |
43 | namespace tvm { |
44 | namespace relay { |
45 | |
46 | class CallGraphEntry; |
47 | class CallGraph; |
48 | |
49 | class CallGraphNode : public Object { |
50 | using CallGraphMap = |
51 | std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectPtrHash, ObjectPtrEqual>; |
52 | // Create iterator alias for a CallGraphNode object. |
53 | using iterator = CallGraphMap::iterator; |
54 | using const_iterator = CallGraphMap::const_iterator; |
55 | |
56 | public: |
57 | /*! \brief The IR module for creating a CallGraphNode. */ |
58 | IRModule module; |
59 | |
60 | /*! \brief Default constructor. */ |
61 | CallGraphNode() {} |
62 | |
63 | void VisitAttrs(AttrVisitor* v) { v->Visit("module" , &module); } |
64 | |
65 | /*! |
66 | * \brief Print the call graph. |
67 | * |
68 | * \param os The stream for printing. |
69 | */ |
70 | void Print(std::ostream& os) const; |
71 | |
72 | /*! \return The begin iterator. */ |
73 | iterator begin() { return call_graph_.begin(); } |
74 | /*! \return The end iterator. */ |
75 | iterator end() { return call_graph_.end(); } |
76 | /*! \return The begin iterator. */ |
77 | const_iterator begin() const { return call_graph_.begin(); } |
78 | /*! \return The end iterator. */ |
79 | const_iterator end() const { return call_graph_.end(); } |
80 | |
81 | /*! |
82 | * \brief Get an element from the CallGraphNode using a GlobalVar. |
83 | * |
84 | * \param gv The GlobalVar used for indexing. |
85 | * |
86 | * \return The fetched element. |
87 | */ |
88 | const CallGraphEntry* operator[](const GlobalVar& gv) const; |
89 | /*! |
90 | * \brief Get an element from the CallGraphNode using a GlobalVar. |
91 | * |
92 | * \param gv The GlobalVar used for indexing. |
93 | * |
94 | * \return The fetched element. |
95 | */ |
96 | CallGraphEntry* operator[](const GlobalVar& gv); |
97 | /*! |
98 | * \brief Get an element from the CallGraphNode using the global function name. |
99 | * |
100 | * \param gvar_name The global function name used for indexing. |
101 | * |
102 | * \return The fetched element. |
103 | */ |
104 | const CallGraphEntry* operator[](const std::string& gvar_name) const { |
105 | return (*this)[module->GetGlobalVar(gvar_name)]; |
106 | } |
107 | /*! |
108 | * \brief Get an element from the CallGraphNode using the global function name. |
109 | * |
110 | * \param gvar_name The global function name used for indexing. |
111 | * |
112 | * \return The fetched element. |
113 | */ |
114 | CallGraphEntry* operator[](const std::string& gvar_name) { |
115 | return (*this)[module->GetGlobalVar(gvar_name)]; |
116 | } |
117 | |
118 | /*! |
119 | * \brief Get the global function corresponding to the variable. |
120 | * |
121 | * \param var The global variable. |
122 | * |
123 | * \return The found global function. |
124 | */ |
125 | BaseFunc GetGlobalFunction(const GlobalVar& var) const; |
126 | |
127 | /*! |
128 | * \brief Get the entries/root nodes of CallGraphNode. |
129 | * |
130 | * Entry functions are never referenced by other functions. |
131 | * Note these functions can be recursive as well. |
132 | * |
133 | * \return The list of CallGraphEntry that represent entry nodes. |
134 | */ |
135 | std::vector<CallGraphEntry*> GetEntryGlobals() const; |
136 | |
137 | /*! |
138 | * \brief Remove a GlobalVar in a given CallGraphEntry from the current |
139 | * IR module. |
140 | * |
141 | * \param cg_node The CallGraphEntry that contains a global function to be |
142 | * removed. |
143 | * \param update_call_graph Indicate if we will update the CallGraph as well |
144 | * since updating is costly. We are only able to remove a leaf function |
145 | * when update_call_graph is disabled because the edges pointing to |
146 | * functions being removed are not updated. |
147 | * |
148 | * \return The GlobalVar removed from the current module. |
149 | */ |
150 | GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph = false); |
151 | |
152 | /*! |
153 | * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for |
154 | * the GlobalVar if it doesn't exist. |
155 | * |
156 | * \param gv The GlobalVar for query. |
157 | * |
158 | * \return The queried entry. |
159 | */ |
160 | CallGraphEntry* LookupGlobalVar(const GlobalVar& gv); |
161 | |
162 | /*! |
163 | * \brief Get the entries from the CallGraphNode in the topological order. |
164 | * |
165 | * This is useful for various module-level optimizations/analysis. For example, |
166 | * inlining requires the correct order of the functions being processed, i.e. |
167 | * callee should be always handled before callers. |
168 | * |
169 | * \return The list of collected entries that are sorted in the topological order. |
170 | */ |
171 | std::vector<CallGraphEntry*> TopologicalOrder() const; |
172 | |
173 | static constexpr const char* _type_key = "relay.CallGraph" ; |
174 | TVM_DECLARE_FINAL_OBJECT_INFO(CallGraphNode, Object); |
175 | |
176 | private: |
177 | /*! |
178 | * \brief Create a CallGraphEntry for a global function and add it to the |
179 | * CallGraphNode. |
180 | * |
181 | * \param gv The global var. |
182 | * \param func The global function corresponding to `gv`. |
183 | */ |
184 | void AddToCallGraph(const GlobalVar& gv, const Function& func); |
185 | |
186 | /*! \brief A record contains GlobalVar to CallGraphEntry mapping. */ |
187 | CallGraphMap call_graph_; |
188 | |
189 | friend CallGraph; |
190 | }; |
191 | |
192 | /*! |
193 | * \brief The class that represents the call graph of a Relay IR module. It also |
194 | * provides a variety of utility functions for users to query, view, and update |
195 | * a call graph. |
196 | */ |
197 | class CallGraph : public ObjectRef { |
198 | using CallGraphMap = |
199 | std::unordered_map<GlobalVar, std::unique_ptr<CallGraphEntry>, ObjectPtrHash, ObjectPtrEqual>; |
200 | // Create iterator alias for a CallGraph object. |
201 | using iterator = CallGraphMap::iterator; |
202 | using const_iterator = CallGraphMap::const_iterator; |
203 | |
204 | public: |
205 | /*! |
206 | * \brief Construct a CallGraph from a IR module. |
207 | * |
208 | * \param module The IR module |
209 | */ |
210 | explicit CallGraph(IRModule module); |
211 | |
212 | /*! |
213 | * \brief Construct from an object pointer. |
214 | * \param n The object pointer. |
215 | */ |
216 | explicit CallGraph(ObjectPtr<Object> n) : ObjectRef(n) {} |
217 | |
218 | /*! \return The begin iterator. */ |
219 | iterator begin() { |
220 | auto* n = operator->(); |
221 | ICHECK(n); |
222 | return n->begin(); |
223 | } |
224 | /*! \return The end iterator. */ |
225 | iterator end() { |
226 | auto* n = operator->(); |
227 | ICHECK(n); |
228 | return n->end(); |
229 | } |
230 | /*! \return The begin iterator. */ |
231 | const_iterator begin() const { |
232 | const auto* n = operator->(); |
233 | ICHECK(n); |
234 | return n->begin(); |
235 | } |
236 | /*! \return The end iterator. */ |
237 | const_iterator end() const { |
238 | const auto* n = operator->(); |
239 | ICHECK(n); |
240 | return n->end(); |
241 | } |
242 | |
243 | /*! |
244 | * \brief Get an element from the CallGraph using a GlobalVar. |
245 | * |
246 | * \param gv The GlobalVar used for indexing. |
247 | * |
248 | * \return The fetched element. |
249 | */ |
250 | const CallGraphEntry* operator[](const GlobalVar& gv) const { |
251 | const auto* n = operator->(); |
252 | ICHECK(n); |
253 | return (*n)[gv]; |
254 | } |
255 | /*! |
256 | * \brief Get an element from the CallGraph using a GlobalVar. |
257 | * |
258 | * \param gv The GlobalVar used for indexing. |
259 | * |
260 | * \return The fetched element. |
261 | */ |
262 | CallGraphEntry* operator[](const GlobalVar& gv) { |
263 | auto* n = operator->(); |
264 | ICHECK(n); |
265 | return (*n)[gv]; |
266 | } |
267 | /*! |
268 | * \brief Get an element from the CallGraph using the global function name. |
269 | * |
270 | * \param gvar_name The global function name used for indexing. |
271 | * |
272 | * \return The fetched element. |
273 | */ |
274 | const CallGraphEntry* operator[](const std::string& gvar_name) const { |
275 | const auto* n = operator->(); |
276 | ICHECK(n); |
277 | return (*n)[gvar_name]; |
278 | } |
279 | /*! |
280 | * \brief Get an element from the CallGraph using the global function name. |
281 | * |
282 | * \param gvar_name The global function name used for indexing. |
283 | * |
284 | * \return The fetched element. |
285 | */ |
286 | CallGraphEntry* operator[](const std::string& gvar_name) { |
287 | auto* n = operator->(); |
288 | ICHECK(n); |
289 | return (*n)[gvar_name]; |
290 | } |
291 | |
292 | /*! \return mutable pointers to the node. */ |
293 | CallGraphNode* operator->() const { |
294 | auto* ptr = get_mutable(); |
295 | ICHECK(ptr != nullptr); |
296 | return static_cast<CallGraphNode*>(ptr); |
297 | } |
298 | |
299 | private: |
300 | /*! \brief Overload the << operator to print a call graph. */ |
301 | friend std::ostream& operator<<(std::ostream& os, const CallGraph&); |
302 | }; |
303 | |
304 | /*! |
305 | * \brief A node in the call graph. It maintains the edges from a caller to |
306 | * all callees. |
307 | */ |
308 | class CallGraphEntry { |
309 | public: |
310 | using CallGraphEntryPair = std::pair<GlobalVar, CallGraphEntry*>; |
311 | using CallGraphEntryVector = std::vector<CallGraphEntryPair>; |
312 | using CallGraphEntrySet = std::unordered_set<const CallGraphEntry*>; |
313 | // Create iterator alias for a CallGraphEntry object. |
314 | using iterator = std::vector<CallGraphEntryPair>::iterator; |
315 | using const_iterator = std::vector<CallGraphEntryPair>::const_iterator; |
316 | |
317 | /*! |
318 | * \brief Construct from a GlobalVar. |
319 | * |
320 | * \param gv The GlobalVar to create a CallGraphEntry. |
321 | */ |
322 | explicit CallGraphEntry(const GlobalVar& gv) : global_(gv) {} |
323 | /*! |
324 | * \brief Delete copy constructor. |
325 | */ |
326 | CallGraphEntry(const CallGraphEntry&) = delete; |
327 | /*! \brief Delete assignment. */ |
328 | CallGraphEntry& operator=(const CallGraphEntry&) = delete; |
329 | |
330 | /*! \return The begin iterator */ |
331 | iterator begin() { return called_globals_.begin(); } |
332 | /*! \return The end iterator */ |
333 | iterator end() { return called_globals_.end(); } |
334 | /*! \return The const begin iterator */ |
335 | const_iterator begin() const { return called_globals_.begin(); } |
336 | /*! \return The const end iterator */ |
337 | const_iterator end() const { return called_globals_.end(); } |
338 | |
339 | /*! |
340 | * \brief Return if the list of called nodes is empty. |
341 | * |
342 | * \return true if the list is empty. Otherwise, false. |
343 | */ |
344 | bool empty() const { return called_globals_.empty(); } |
345 | |
346 | /*! |
347 | * \brief Return the size of the list that represents the nodes are called by |
348 | * the current node. |
349 | * |
350 | * \return The number of called nodes. |
351 | */ |
352 | uint32_t size() const { return static_cast<uint32_t>(called_globals_.size()); } |
353 | |
354 | /*! |
355 | * \brief Fetch the i-th CallGraphEntry from the list of nodes that are called |
356 | * by the current function. |
357 | * |
358 | * \param i The index. |
359 | * |
360 | * \return The fetched CallGraphEntry. |
361 | */ |
362 | CallGraphEntry* operator[](size_t i) const { |
363 | ICHECK_LT(i, called_globals_.size()) << "Invalid Index" ; |
364 | return called_globals_[i].second; |
365 | } |
366 | |
367 | /*! |
368 | * \brief Print the call graph that is stemmed from the current CallGraphEntry. |
369 | * |
370 | * \param os The stream for printing. |
371 | */ |
372 | void Print(std::ostream& os) const; |
373 | |
374 | /*! |
375 | * \brief Return the number of times the global function is referenced. |
376 | * |
377 | * \return The count. |
378 | */ |
379 | uint32_t GetRefCount() const { return ref_cnt_; } |
380 | |
381 | /*! |
382 | * \brief Return the GlobalVar stored in the current CallGraphEntry. |
383 | * |
384 | * \return The GlobalVar. |
385 | */ |
386 | GlobalVar GetGlobalVar() const { return global_; } |
387 | |
388 | /*! |
389 | * \brief Return the name hint of the GlobalVar stored in the CallGraphEntry. |
390 | * |
391 | * \return The name hint of the global function. |
392 | */ |
393 | std::string GetNameHint() const { return global_->name_hint; } |
394 | |
395 | /*! |
396 | * \brief Return if the global function corresponding to the current |
397 | * CallGraphEntry is a recursive function. |
398 | * |
399 | * \return true if it is recursive. Otherwise, false. |
400 | */ |
401 | bool IsRecursive() const { return is_recursive_; } |
402 | |
403 | /*! |
404 | * \brief Return if the global function corresponding to the current |
405 | * CallGraphEntry is both a recursive function and an entry function. This type |
406 | * of function only has one reference which is called by itself. |
407 | * |
408 | * \return true if it is both a recursive function and an entry. Otherwise, false. |
409 | */ |
410 | bool IsRecursiveEntry() const { return GetRefCount() == 1 && IsRecursive(); } |
411 | |
412 | /*! |
413 | * \brief Return the topological order of the CallGraphEntry. |
414 | * |
415 | * \param visited A set of CallGraphEntry objects that have been visited. |
416 | * |
417 | * \return The list of CallGraphEntry that is represented in topological order. |
418 | */ |
419 | std::vector<CallGraphEntry*> TopologicalOrder( |
420 | CallGraphEntrySet* visited = new CallGraphEntrySet()) const; |
421 | |
422 | /*! |
423 | * \brief Remove all edges from the current CallGraphEntry to any global |
424 | * function it calls. |
425 | */ |
426 | void CleanCallGraphEntries(); |
427 | |
428 | /*! |
429 | * \brief Add a node to the list of nodes that are being called by the current |
430 | * global function. |
431 | * |
432 | * \param cg_node The CallGraphEntry that will be added to the call list. |
433 | */ |
434 | void AddCalledGlobal(CallGraphEntry* cg_node); |
435 | |
436 | /*! |
437 | * \brief Remove a call edge to the global function from the current |
438 | * function. |
439 | * |
440 | * \param callee The function that is being called. |
441 | */ |
442 | void RemoveCallTo(const GlobalVar& callee); |
443 | |
444 | /*! |
445 | * \brief Remove all the edges that represent that calls to the global function |
446 | * stored in a given CallGraphEntry. |
447 | * |
448 | * \param callee The function that is being called. |
449 | */ |
450 | void RemoveAllCallTo(CallGraphEntry* callee); |
451 | |
452 | private: |
453 | /*! \brief Decrement the reference counter by 1. */ |
454 | void DecRef() { |
455 | ICHECK_GT(ref_cnt_, 0); |
456 | --ref_cnt_; |
457 | } |
458 | /*! \brief Increment the reference counter by 1. */ |
459 | void IncRef() { ++ref_cnt_; } |
460 | |
461 | /*! |
462 | * \brief Mark if the global function stored in the CallGraphEntry is |
463 | * recursive function. |
464 | */ |
465 | bool is_recursive_{false}; |
466 | /*! \brief Count the number of times the global function is referenced. */ |
467 | uint32_t ref_cnt_{0}; |
468 | /*! \brief The GlobalVar stored in the current CallGraphEntry. */ |
469 | GlobalVar global_; |
470 | /*! \brief The list of entries called by the current CallGraphEntry. */ |
471 | CallGraphEntryVector called_globals_; |
472 | |
473 | friend class CallGraph; |
474 | /*! \brief Overload the << operator to print a call graph node. */ |
475 | friend std::ostream& operator<<(std::ostream& os, const CallGraphEntry&); |
476 | }; |
477 | |
478 | } // namespace relay |
479 | } // namespace tvm |
480 | #endif // TVM_RELAY_ANALYSIS_CALL_GRAPH_H_ |
481 | |