1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file eta_expand.cc
 *
 * \brief Eta-expand constructors and/or global variables that refer to functions: each
 *        unapplied reference is wrapped in a lambda that forwards its parameters to it.
 *
 */
26 | #include <tvm/ir/type_functor.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/transform.h> |
29 | #include <tvm/relay/type.h> |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | namespace eta_expand { |
34 | |
35 | /*! |
36 | * \brief mutator to replace type variables with fresh ones, while maintaining alpha equality |
37 | */ |
38 | class TypeVarReplacer : public TypeMutator { |
39 | public: |
40 | TypeVarReplacer() : replace_map_({}) {} |
41 | |
42 | Type VisitType_(const TypeVarNode* type_var_node) final { |
43 | const auto type_var = GetRef<TypeVar>(type_var_node); |
44 | if (replace_map_.find(type_var) == replace_map_.end()) { |
45 | replace_map_[type_var] = TypeVar("A" , Kind::kType); |
46 | } |
47 | return replace_map_[type_var]; |
48 | } |
49 | |
50 | private: |
51 | /*! \brief variable replacement map to remap old type vars to fresh ones */ |
52 | std::unordered_map<TypeVar, TypeVar, ObjectPtrHash, ObjectPtrEqual> replace_map_; |
53 | }; |
54 | |
55 | /*! |
56 | * \brief mutator to perform eta expansion on all functions in a module |
57 | */ |
58 | class EtaExpander : public ExprMutator { |
59 | public: |
60 | explicit EtaExpander(const IRModule& mod, bool expand_constructor, bool expand_global_var) |
61 | : mod_(mod), |
62 | type_var_replacer_(TypeVarReplacer()), |
63 | expand_constructor_(expand_constructor), |
64 | expand_global_var_(expand_global_var) { |
65 | ICHECK(expand_constructor || expand_global_var) << "must expand at least one language feature" ; |
66 | } |
67 | |
68 | IRModule Expand() { |
69 | for (GlobalVar global_var : mod_->GetGlobalVars()) { |
70 | const BaseFunc base_func = mod_->Lookup(global_var); |
71 | if (auto* n = base_func.as<FunctionNode>()) { |
72 | const Function new_func = Downcast<Function>(VisitExpr(GetRef<Function>(n))); |
73 | mod_->Update(global_var, new_func); |
74 | } |
75 | } |
76 | return mod_; |
77 | } |
78 | |
79 | Expr VisitExpr_(const CallNode* call) final { |
80 | // we don't need to expand constructors when they are being called, so we |
81 | // prevent them being visited here |
82 | Expr new_op = call->op; |
83 | if (!call->op.as<ConstructorNode>()) { |
84 | new_op = VisitExpr(new_op); |
85 | } |
86 | tvm::Array<Expr> new_args; |
87 | for (const auto& arg : call->args) { |
88 | new_args.push_back(VisitExpr(arg)); |
89 | } |
90 | return Call(new_op, new_args, call->attrs, call->type_args); |
91 | } |
92 | |
93 | Expr VisitExpr_(const ConstructorNode* cons_node) final { |
94 | Constructor cons = GetRef<Constructor>(cons_node); |
95 | if (!expand_constructor_) { |
96 | return std::move(cons); |
97 | } |
98 | // NOTE: we only reach this case if the constructor is not being applied to any arguments |
99 | tvm::Array<Expr> params; |
100 | for (const auto& type : cons->inputs) { |
101 | Type param_type = type_var_replacer_.VisitType(type); |
102 | params.push_back(Var("eta_expand_param" , param_type)); |
103 | } |
104 | tvm::Array<Type> type_params; |
105 | TypeData adt_def = mod_->LookupTypeDef(cons->belong_to); |
106 | for (const auto& type_var : adt_def->type_vars) { |
107 | type_params.push_back(type_var_replacer_.VisitType(type_var)); |
108 | } |
109 | Expr body = Call(cons, params, Attrs()); |
110 | Type ret_type = TypeCall(cons->belong_to, type_params); |
111 | |
112 | return Function(Downcast<tvm::Array<Var>>(params), body, ret_type, |
113 | Downcast<tvm::Array<TypeVar>>(type_params)); |
114 | } |
115 | |
116 | Expr VisitExpr_(const GlobalVarNode* gvar_node) final { |
117 | GlobalVar gvar = GetRef<GlobalVar>(gvar_node); |
118 | if (!expand_global_var_) { |
119 | return std::move(gvar); |
120 | } |
121 | const auto base_func = mod_->Lookup(gvar); |
122 | if (auto* ptr = base_func.as<FunctionNode>()) { |
123 | // handle relay function, skip external functions. |
124 | auto func = GetRef<Function>(ptr); |
125 | tvm::Array<Expr> params; |
126 | tvm::Array<Var> args; |
127 | for (size_t i = 0; i < func->params.size(); ++i) { |
128 | auto var = Var("eta_expand_param" , func->params[i]->type_annotation); |
129 | params.push_back(var); |
130 | args.push_back(var); |
131 | } |
132 | return WithFields(func, args, Call(gvar, params)); |
133 | } else { |
134 | return std::move(gvar); |
135 | } |
136 | } |
137 | |
138 | private: |
139 | /*! \brief reference to module being expanded */ |
140 | const IRModule mod_; |
141 | /*! \brief type variable replacer */ |
142 | TypeVarReplacer type_var_replacer_; |
143 | /*! \brief whether to expand constructor nodes */ |
144 | bool expand_constructor_; |
145 | /*! \brief whether to expand global variable nodes */ |
146 | bool expand_global_var_; |
147 | }; |
148 | |
149 | } // namespace eta_expand |
150 | |
151 | namespace transform { |
152 | |
153 | Pass EtaExpand(bool expand_constructor, bool expand_global_var) { |
154 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod, |
155 | PassContext pc) { |
156 | return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand(); |
157 | }; |
158 | return CreateModulePass(pass_func, 1, "EtaExpand" , {}); |
159 | } |
160 | |
161 | TVM_REGISTER_GLOBAL("relay._transform.EtaExpand" ).set_body_typed(EtaExpand); |
162 | |
163 | } // namespace transform |
164 | |
165 | } // namespace relay |
166 | } // namespace tvm |
167 | |