1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file eta_expand.cc
22 *
23 * \brief Add an abstraction over constructors and/or global variables bound to a function.
24 *
25 */
26#include <tvm/ir/type_functor.h>
27#include <tvm/relay/expr_functor.h>
28#include <tvm/relay/transform.h>
29#include <tvm/relay/type.h>
30
31namespace tvm {
32namespace relay {
33namespace eta_expand {
34
35/*!
36 * \brief mutator to replace type variables with fresh ones, while maintaining alpha equality
37 */
38class TypeVarReplacer : public TypeMutator {
39 public:
40 TypeVarReplacer() : replace_map_({}) {}
41
42 Type VisitType_(const TypeVarNode* type_var_node) final {
43 const auto type_var = GetRef<TypeVar>(type_var_node);
44 if (replace_map_.find(type_var) == replace_map_.end()) {
45 replace_map_[type_var] = TypeVar("A", Kind::kType);
46 }
47 return replace_map_[type_var];
48 }
49
50 private:
51 /*! \brief variable replacement map to remap old type vars to fresh ones */
52 std::unordered_map<TypeVar, TypeVar, ObjectPtrHash, ObjectPtrEqual> replace_map_;
53};
54
55/*!
56 * \brief mutator to perform eta expansion on all functions in a module
57 */
58class EtaExpander : public ExprMutator {
59 public:
60 explicit EtaExpander(const IRModule& mod, bool expand_constructor, bool expand_global_var)
61 : mod_(mod),
62 type_var_replacer_(TypeVarReplacer()),
63 expand_constructor_(expand_constructor),
64 expand_global_var_(expand_global_var) {
65 ICHECK(expand_constructor || expand_global_var) << "must expand at least one language feature";
66 }
67
68 IRModule Expand() {
69 for (GlobalVar global_var : mod_->GetGlobalVars()) {
70 const BaseFunc base_func = mod_->Lookup(global_var);
71 if (auto* n = base_func.as<FunctionNode>()) {
72 const Function new_func = Downcast<Function>(VisitExpr(GetRef<Function>(n)));
73 mod_->Update(global_var, new_func);
74 }
75 }
76 return mod_;
77 }
78
79 Expr VisitExpr_(const CallNode* call) final {
80 // we don't need to expand constructors when they are being called, so we
81 // prevent them being visited here
82 Expr new_op = call->op;
83 if (!call->op.as<ConstructorNode>()) {
84 new_op = VisitExpr(new_op);
85 }
86 tvm::Array<Expr> new_args;
87 for (const auto& arg : call->args) {
88 new_args.push_back(VisitExpr(arg));
89 }
90 return Call(new_op, new_args, call->attrs, call->type_args);
91 }
92
93 Expr VisitExpr_(const ConstructorNode* cons_node) final {
94 Constructor cons = GetRef<Constructor>(cons_node);
95 if (!expand_constructor_) {
96 return std::move(cons);
97 }
98 // NOTE: we only reach this case if the constructor is not being applied to any arguments
99 tvm::Array<Expr> params;
100 for (const auto& type : cons->inputs) {
101 Type param_type = type_var_replacer_.VisitType(type);
102 params.push_back(Var("eta_expand_param", param_type));
103 }
104 tvm::Array<Type> type_params;
105 TypeData adt_def = mod_->LookupTypeDef(cons->belong_to);
106 for (const auto& type_var : adt_def->type_vars) {
107 type_params.push_back(type_var_replacer_.VisitType(type_var));
108 }
109 Expr body = Call(cons, params, Attrs());
110 Type ret_type = TypeCall(cons->belong_to, type_params);
111
112 return Function(Downcast<tvm::Array<Var>>(params), body, ret_type,
113 Downcast<tvm::Array<TypeVar>>(type_params));
114 }
115
116 Expr VisitExpr_(const GlobalVarNode* gvar_node) final {
117 GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
118 if (!expand_global_var_) {
119 return std::move(gvar);
120 }
121 const auto base_func = mod_->Lookup(gvar);
122 if (auto* ptr = base_func.as<FunctionNode>()) {
123 // handle relay function, skip external functions.
124 auto func = GetRef<Function>(ptr);
125 tvm::Array<Expr> params;
126 tvm::Array<Var> args;
127 for (size_t i = 0; i < func->params.size(); ++i) {
128 auto var = Var("eta_expand_param", func->params[i]->type_annotation);
129 params.push_back(var);
130 args.push_back(var);
131 }
132 return WithFields(func, args, Call(gvar, params));
133 } else {
134 return std::move(gvar);
135 }
136 }
137
138 private:
139 /*! \brief reference to module being expanded */
140 const IRModule mod_;
141 /*! \brief type variable replacer */
142 TypeVarReplacer type_var_replacer_;
143 /*! \brief whether to expand constructor nodes */
144 bool expand_constructor_;
145 /*! \brief whether to expand global variable nodes */
146 bool expand_global_var_;
147};
148
149} // namespace eta_expand
150
151namespace transform {
152
153Pass EtaExpand(bool expand_constructor, bool expand_global_var) {
154 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
155 PassContext pc) {
156 return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand();
157 };
158 return CreateModulePass(pass_func, 1, "EtaExpand", {});
159}
160
161TVM_REGISTER_GLOBAL("relay._transform.EtaExpand").set_body_typed(EtaExpand);
162
163} // namespace transform
164
165} // namespace relay
166} // namespace tvm
167