1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 *
14 * software distributed under the License is distributed on an
15 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 * KIND, either express or implied. See the License for the
17 * specific language governing permissions and limitations
18 * under the License.
19 */
20
21/*!
22 * \file tvm/relay/backend/vm/remove_unused_funcs.cc
23 * \brief Remove unused global relay functions in a relay module.
24 */
25
26#include <tvm/relay/analysis.h>
27#include <tvm/relay/attrs/annotation.h>
28#include <tvm/relay/expr.h>
29#include <tvm/relay/expr_functor.h>
30#include <tvm/relay/transform.h>
31#include <tvm/runtime/logging.h>
32
33#include <iostream>
34#include <unordered_set>
35#include <vector>
36
37#include "../../op/call/call.h"
38
39namespace tvm {
40namespace relay {
41namespace vm {
42
43/**
44 * \brief Detects all the functions that can be possibly called by entry function.
45 */
46struct CallTracer : ExprVisitor {
47 IRModule module_;
48
49 // Record the names of all encountered functions
50 std::unordered_set<std::string> called_funcs_;
51
52 // Record the expressions that are being visited
53 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
54
55 explicit CallTracer(const IRModule& module) : module_{module}, called_funcs_{}, visiting_{} {}
56
57 void VisitExpr_(const GlobalVarNode* op) final {
58 called_funcs_.insert(op->name_hint);
59 auto func = module_->Lookup(op->name_hint);
60 if (const auto* function_node = func.as<FunctionNode>()) {
61 VisitExpr(GetRef<Function>(function_node));
62 }
63 // else: Don't visit PrimFuncs -- we don't need to collect any tir.Calls therein.
64 }
65
66 void VisitExpr_(const CallNode* call_node) final {
67 // TODO(mbs): Cleanup shape functions.
68 CallLoweredProps props = GetCallLoweredProps(call_node);
69 if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) {
70 auto callee = Downcast<GlobalVar>(props.attrs.metadata["prim_shape_fn_var"]);
71 // We are implicitly calling the shape function *in addition to* the callee.
72 called_funcs_.insert(callee->name_hint);
73 }
74 ExprVisitor::VisitExpr_(call_node);
75 }
76
77 void VisitExpr_(const FunctionNode* func_node) final {
78 auto func = GetRef<Function>(func_node);
79 if (visiting_.find(func) == visiting_.end()) {
80 visiting_.insert(func);
81 for (auto param : func_node->params) {
82 ExprVisitor::VisitExpr(param);
83 }
84 ExprVisitor::VisitExpr(func_node->body);
85 }
86 }
87
88 std::unordered_set<std::string> Trace(const std::string& entry) {
89 called_funcs_.insert(entry);
90 auto main_func = module_->Lookup(entry);
91 VisitExpr(main_func);
92 return called_funcs_;
93 }
94};
95
96/*!
97 * \brief Remove functions that are not used.
98 *
99 * \param module The Relay module.
100 * \param entry_funcs The set of functions that can be entry function.
101 *
102 * \return The module with dead functions removed.
103 */
104IRModule RemoveUnusedFunctions(const IRModule& module, Array<runtime::String> entry_funcs) {
105 std::unordered_set<std::string> called_funcs{};
106 for (auto entry : entry_funcs) {
107 auto funcs = CallTracer(module).Trace(entry);
108 called_funcs.insert(funcs.cbegin(), funcs.cend());
109 }
110 auto existing_functions = module->functions;
111 for (auto f : existing_functions) {
112 auto it = called_funcs.find(f.first->name_hint);
113 if (it == called_funcs.end()) {
114 module->Remove(f.first);
115 }
116 }
117 return module;
118}
119
120} // namespace vm
121
122namespace transform {
123
124Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions) {
125 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
126 PassContext pc) {
127 return relay::vm::RemoveUnusedFunctions(m, entry_functions);
128 };
129 return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {});
130}
131
132TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions);
133
134} // namespace transform
135
136} // namespace relay
137} // namespace tvm
138