1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * |
14 | * software distributed under the License is distributed on an |
15 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
16 | * KIND, either express or implied. See the License for the |
17 | * specific language governing permissions and limitations |
18 | * under the License. |
19 | */ |
20 | |
21 | /*! |
22 | * \file tvm/relay/backend/vm/remove_unused_funcs.cc |
23 | * \brief Remove unused global relay functions in a relay module. |
24 | */ |
25 | |
26 | #include <tvm/relay/analysis.h> |
27 | #include <tvm/relay/attrs/annotation.h> |
28 | #include <tvm/relay/expr.h> |
29 | #include <tvm/relay/expr_functor.h> |
30 | #include <tvm/relay/transform.h> |
31 | #include <tvm/runtime/logging.h> |
32 | |
33 | #include <iostream> |
34 | #include <unordered_set> |
35 | #include <vector> |
36 | |
37 | #include "../../op/call/call.h" |
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | namespace vm { |
42 | |
43 | /** |
44 | * \brief Detects all the functions that can be possibly called by entry function. |
45 | */ |
46 | struct CallTracer : ExprVisitor { |
47 | IRModule module_; |
48 | |
49 | // Record the names of all encountered functions |
50 | std::unordered_set<std::string> called_funcs_; |
51 | |
52 | // Record the expressions that are being visited |
53 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_; |
54 | |
55 | explicit CallTracer(const IRModule& module) : module_{module}, called_funcs_{}, visiting_{} {} |
56 | |
57 | void VisitExpr_(const GlobalVarNode* op) final { |
58 | called_funcs_.insert(op->name_hint); |
59 | auto func = module_->Lookup(op->name_hint); |
60 | if (const auto* function_node = func.as<FunctionNode>()) { |
61 | VisitExpr(GetRef<Function>(function_node)); |
62 | } |
63 | // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein. |
64 | } |
65 | |
66 | void VisitExpr_(const CallNode* call_node) final { |
67 | // TODO(mbs): Cleanup shape functions. |
68 | CallLoweredProps props = GetCallLoweredProps(call_node); |
69 | if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var" )) { |
70 | auto callee = Downcast<GlobalVar>(props.attrs.metadata["prim_shape_fn_var" ]); |
71 | // We are implicitly calling the shape function *in addition to* the callee. |
72 | called_funcs_.insert(callee->name_hint); |
73 | } |
74 | ExprVisitor::VisitExpr_(call_node); |
75 | } |
76 | |
77 | void VisitExpr_(const FunctionNode* func_node) final { |
78 | auto func = GetRef<Function>(func_node); |
79 | if (visiting_.find(func) == visiting_.end()) { |
80 | visiting_.insert(func); |
81 | for (auto param : func_node->params) { |
82 | ExprVisitor::VisitExpr(param); |
83 | } |
84 | ExprVisitor::VisitExpr(func_node->body); |
85 | } |
86 | } |
87 | |
88 | std::unordered_set<std::string> Trace(const std::string& entry) { |
89 | called_funcs_.insert(entry); |
90 | auto main_func = module_->Lookup(entry); |
91 | VisitExpr(main_func); |
92 | return called_funcs_; |
93 | } |
94 | }; |
95 | |
96 | /*! |
97 | * \brief Remove functions that are not used. |
98 | * |
99 | * \param module The Relay module. |
100 | * \param entry_funcs The set of functions that can be entry function. |
101 | * |
102 | * \return The module with dead functions removed. |
103 | */ |
104 | IRModule RemoveUnusedFunctions(const IRModule& module, Array<runtime::String> entry_funcs) { |
105 | std::unordered_set<std::string> called_funcs{}; |
106 | for (auto entry : entry_funcs) { |
107 | auto funcs = CallTracer(module).Trace(entry); |
108 | called_funcs.insert(funcs.cbegin(), funcs.cend()); |
109 | } |
110 | auto existing_functions = module->functions; |
111 | for (auto f : existing_functions) { |
112 | auto it = called_funcs.find(f.first->name_hint); |
113 | if (it == called_funcs.end()) { |
114 | module->Remove(f.first); |
115 | } |
116 | } |
117 | return module; |
118 | } |
119 | |
120 | } // namespace vm |
121 | |
122 | namespace transform { |
123 | |
124 | Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) { |
125 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m, |
126 | PassContext pc) { |
127 | return relay::vm::RemoveUnusedFunctions(m, entry_functions); |
128 | }; |
129 | return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions" , {}); |
130 | } |
131 | |
132 | TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions" ).set_body_typed(RemoveUnusedFunctions); |
133 | |
134 | } // namespace transform |
135 | |
136 | } // namespace relay |
137 | } // namespace tvm |
138 | |