1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/ir/function.cc |
22 | * \brief Function in relay. |
23 | */ |
24 | #include <tvm/ir/type_functor.h> |
25 | #include <tvm/relay/analysis.h> |
26 | #include <tvm/relay/expr_functor.h> |
27 | #include <tvm/relay/function.h> |
28 | #include <tvm/relay/transform.h> |
29 | |
30 | namespace tvm { |
31 | namespace relay { |
32 | |
33 | Function::Function(tvm::Array<Var> params, Expr body, Type ret_type, |
34 | tvm::Array<TypeVar> type_params, DictAttrs attrs, Span span) { |
35 | ObjectPtr<FunctionNode> n = make_object<FunctionNode>(); |
36 | ICHECK(params.defined()); |
37 | ICHECK(type_params.defined()); |
38 | n->params = std::move(params); |
39 | n->body = std::move(body); |
40 | n->ret_type = std::move(ret_type); |
41 | n->type_params = std::move(type_params); |
42 | n->attrs = std::move(attrs); |
43 | n->virtual_device_ = VirtualDevice::FullyUnconstrained(); |
44 | n->span = std::move(span); |
45 | data_ = std::move(n); |
46 | } |
47 | |
48 | Function WithFields(Function function, Optional<Array<Var>> opt_params, Optional<Expr> opt_body, |
49 | Optional<Type> opt_ret_type, Optional<Array<TypeVar>> opt_ty_params, |
50 | Optional<DictAttrs> opt_attrs, Optional<VirtualDevice> opt_virtual_device, |
51 | Optional<Span> opt_span) { |
52 | Array<Var> params = opt_params.value_or(function->params); |
53 | Expr body = opt_body.value_or(function->body); |
54 | Type ret_type = opt_ret_type.value_or(function->ret_type); |
55 | Array<TypeVar> ty_params = opt_ty_params.value_or(function->type_params); |
56 | DictAttrs attrs = opt_attrs.value_or(function->attrs); |
57 | VirtualDevice virtual_device = opt_virtual_device.value_or(function->virtual_device()); |
58 | Span span = opt_span.value_or(function->span); |
59 | |
60 | bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) && |
61 | attrs.same_as(function->attrs) && |
62 | virtual_device.same_as(function->virtual_device()) && |
63 | span.same_as(function->span); |
64 | |
65 | // Check that all the type params are unchanged |
66 | if (unchanged) { |
67 | bool all_ty_params_unchanged = true; |
68 | if (ty_params.size() == function->type_params.size()) { |
69 | for (size_t i = 0; i < ty_params.size(); i++) { |
70 | all_ty_params_unchanged &= ty_params[i].same_as(function->type_params[i]); |
71 | } |
72 | } else { |
73 | all_ty_params_unchanged = false; |
74 | } |
75 | unchanged &= all_ty_params_unchanged; |
76 | } |
77 | |
78 | // Check that all the params are unchanged |
79 | if (unchanged) { |
80 | bool all_params_unchanged = true; |
81 | if (params.size() == function->params.size()) { |
82 | for (size_t i = 0; i < params.size(); i++) { |
83 | all_params_unchanged &= params[i].same_as(function->params[i]); |
84 | } |
85 | } else { |
86 | all_params_unchanged = false; |
87 | } |
88 | unchanged &= all_params_unchanged; |
89 | } |
90 | |
91 | if (!unchanged) { |
92 | FunctionNode* cow_function_node = function.CopyOnWrite(); |
93 | cow_function_node->params = params; |
94 | cow_function_node->body = body; |
95 | cow_function_node->ret_type = ret_type; |
96 | cow_function_node->type_params = ty_params; |
97 | cow_function_node->attrs = attrs; |
98 | cow_function_node->virtual_device_ = virtual_device; |
99 | cow_function_node->span = span; |
100 | } |
101 | return function; |
102 | } |
103 | |
104 | FuncType FunctionNode::func_type_annotation() const { |
105 | Array<Type> param_types; |
106 | for (auto param : this->params) { |
107 | Type param_type = |
108 | (param->type_annotation.defined()) ? param->type_annotation : IncompleteType(Kind::kType); |
109 | param_types.push_back(param_type); |
110 | } |
111 | |
112 | Type ret_type = (this->ret_type.defined()) ? this->ret_type : IncompleteType(Kind::kType); |
113 | return FuncType(param_types, ret_type, this->type_params, {}); |
114 | } |
115 | |
116 | const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) { |
117 | if (const auto* function_node = base_func.as<FunctionNode>()) { |
118 | if (!function_node->GetAttr<String>(attr::kCompiler).defined() && |
119 | !function_node->HasNonzeroAttr(attr::kExtern) && |
120 | !function_node->HasNonzeroAttr(attr::kSkipOptimization)) { |
121 | return function_node; |
122 | } |
123 | } |
124 | return nullptr; |
125 | } |
126 | |
127 | TVM_REGISTER_GLOBAL("relay.ir.PrintRelayModule" ) |
128 | .set_body_typed([](IRModule mod) -> Optional<String> { |
129 | for (const auto& it : mod->functions) { |
130 | if (it.second->IsInstance<FunctionNode>()) { |
131 | return PrettyPrint(mod); |
132 | } |
133 | } |
134 | return NullOpt; |
135 | }); |
136 | |
137 | TVM_REGISTER_GLOBAL("relay.ir.PrintIR" ) |
138 | .set_body_typed([](IRModule mod, String , bool show_metadata) -> bool { |
139 | for (const auto& it : mod->functions) { |
140 | if (it.second->IsInstance<FunctionNode>()) { |
141 | LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata); |
142 | return true; |
143 | } |
144 | } |
145 | return false; |
146 | }); |
147 | |
148 | TVM_REGISTER_GLOBAL("relay.ir.WarnIfMalformed" ) |
149 | .set_body_typed([](const IRModule& mod, const BaseFunc& base_func) -> void { |
150 | if (const auto* relay_func = base_func.as<FunctionNode>()) { |
151 | Function func = Downcast<relay::Function>(relay::DeDup(GetRef<Function>(relay_func))); |
152 | // Type check the item before we add it to the module. |
153 | auto fv = relay::FreeVars(func); |
154 | auto ftv = relay::FreeTypeVars(func, mod); |
155 | // TODO(@jroesch): refactor to use diagnostic context |
156 | ICHECK_EQ(fv.size(), 0) << "Function:" << std::endl |
157 | << PrettyPrint(func) << std::endl |
158 | << "contains free variables: " << fv; |
159 | ICHECK_EQ(ftv.size(), 0) << "Function:" << std::endl |
160 | << PrettyPrint(func) << std::endl |
161 | << "contains free type variables: " << fv; |
162 | } |
163 | }); |
164 | TVM_REGISTER_GLOBAL("relay.ir.IRModuleAdd" ) |
165 | .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule { |
166 | if (val->IsInstance<BaseFuncNode>()) { |
167 | mod->Add(var, Downcast<BaseFunc>(val), update); |
168 | } else if (val->IsInstance<GlobalVarNode>()) { |
169 | GlobalVar gv = Downcast<GlobalVar>(val); |
170 | IRModule mod_copy(make_object<IRModuleNode>(*mod.operator->())); |
171 | mod_copy = relay::transform::EtaExpand( |
172 | /* expand_constructor */ false, |
173 | /* expand_global_var */ true)(mod_copy); |
174 | auto func = mod_copy->Lookup(gv->name_hint); |
175 | mod->Add(var, Downcast<relay::Function>(func), update); |
176 | } else { |
177 | auto func = relay::Function({}, Downcast<RelayExpr>(val), Type(nullptr), {}); |
178 | mod->Add(var, func, update); |
179 | } |
180 | return mod; |
181 | }); |
182 | |
183 | TVM_REGISTER_GLOBAL("relay.ir.IRModuleUpdateWithRenamer" ) |
184 | .set_body_typed([](IRModule self, IRModule mod) -> void { |
185 | struct Renamer : relay::ExprMutator, TypeMutator { |
186 | Map<String, GlobalVar> defs; |
187 | Map<String, GlobalTypeVar> types; |
188 | std::unordered_map<int32_t, Constructor> ctors; |
189 | |
190 | Renamer(Map<String, GlobalVar> defs_one, Map<String, GlobalVar> defs_two, |
191 | Map<String, GlobalTypeVar> types_one, Map<String, GlobalTypeVar> types_two, |
192 | std::unordered_map<int32_t, Constructor> ctors_one, |
193 | std::unordered_map<int32_t, Constructor> ctor_two) { |
194 | for (auto pair : defs_one) { |
195 | defs.Set(pair.first, pair.second); |
196 | } |
197 | |
198 | for (auto pair : defs_two) { |
199 | auto it = defs.find(pair.first); |
200 | if (it == defs.end()) { |
201 | defs.Set(pair.first, pair.second); |
202 | } |
203 | } |
204 | |
205 | for (auto pair : types_one) { |
206 | types.Set(pair.first, pair.second); |
207 | } |
208 | |
209 | for (auto pair : types_two) { |
210 | auto it = types.find(pair.first); |
211 | if (it == types.end()) { |
212 | types.Set(pair.first, pair.second); |
213 | } |
214 | } |
215 | } |
216 | |
217 | relay::Expr VisitExpr_(const GlobalVarNode* node) override { |
218 | return defs.at(node->name_hint); |
219 | } |
220 | |
221 | Type VisitType_(const GlobalTypeVarNode* node) override { |
222 | return types.at(node->name_hint); |
223 | } |
224 | }; |
225 | |
226 | Renamer renamer(self->global_var_map_, mod->global_var_map_, self->global_type_var_map_, |
227 | mod->global_type_var_map_, self->constructor_tag_map_, |
228 | mod->constructor_tag_map_); |
229 | |
230 | self->global_var_map_ = renamer.defs; |
231 | self->global_type_var_map_ = renamer.types; |
232 | self->constructor_tag_map_ = renamer.ctors; |
233 | |
234 | for (auto pair : mod->type_definitions) { |
235 | auto tvar = renamer.types.at(pair.first->name_hint); |
236 | auto ty = renamer.ExprMutator::VisitType(pair.second); |
237 | self->AddTypeDefUnchecked(tvar, Downcast<TypeData>(ty), true); |
238 | } |
239 | |
240 | for (auto pair : mod->functions) { |
241 | if (auto rfn = pair.second.as<relay::FunctionNode>()) { |
242 | auto gvar = renamer.defs.at(pair.first->name_hint); |
243 | auto fn = renamer.VisitExpr(GetRef<relay::Function>(rfn)); |
244 | self->AddUnchecked(gvar, Downcast<BaseFunc>(fn)); |
245 | } else { |
246 | // TODO(@jroesch): rename into IRModule. |
247 | self->AddUnchecked(pair.first, pair.second); |
248 | } |
249 | } |
250 | }); |
251 | |
252 | TVM_REGISTER_GLOBAL("relay.ir.FunctionFromExprInContext" ) |
253 | .set_body_typed([](RelayExpr expr, IRModule mod) -> Function { |
254 | return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); |
255 | }); |
256 | |
257 | TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr" ) |
258 | .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional<Function> { |
259 | if (func->IsInstance<relay::FunctionNode>()) { |
260 | return WithAttr(Downcast<relay::Function>(std::move(func)), key, value); |
261 | } |
262 | return NullOpt; |
263 | }); |
264 | |
265 | TVM_REGISTER_GLOBAL("relay.ir.FuncWithoutAttr" ) |
266 | .set_body_typed([](BaseFunc func, String key) -> Optional<Function> { |
267 | if (func->IsInstance<relay::FunctionNode>()) { |
268 | return WithoutAttr(Downcast<relay::Function>(std::move(func)), key); |
269 | } |
270 | return NullOpt; |
271 | }); |
272 | |
// Register FunctionNode as a TVM node type.
TVM_REGISTER_NODE_TYPE(FunctionNode);
274 | |
275 | TVM_REGISTER_GLOBAL("relay.ir.Function" ) |
276 | .set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type, |
277 | tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs, Span span) { |
278 | return Function(params, body, ret_type, ty_params, attrs, span); |
279 | }); |
280 | TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields" ) |
281 | .set_body_typed([](Function function, Optional<Array<Var>> opt_params, Optional<Expr> opt_body, |
282 | Optional<Type> opt_ret_type, Optional<Array<TypeVar>> opt_ty_params, |
283 | Optional<DictAttrs> opt_attrs, Optional<VirtualDevice> opt_virtual_device, |
284 | Optional<Span> opt_span) { |
285 | return WithFields(function, opt_params, opt_body, opt_ret_type, opt_ty_params, opt_attrs, |
286 | opt_virtual_device, opt_span); |
287 | }); |
288 | |
289 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
290 | .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) { |
291 | // TODO(@jroesch): previously this had a debug printer, the debug printer |
292 | // can cause exponential behavior and is currently dangerous, for these |
293 | // cases we need some kind of de-duping. |
294 | // |
295 | // See old implementation: |
296 | // |
297 | // auto* node = static_cast<const FunctionNode*>(ref.get()); |
298 | // p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << |
299 | // node->body |
300 | // << ", " << node->type_params << ", " << node->attrs << ")"; |
301 | p->stream << PrettyPrint(ref); |
302 | }); |
303 | |
304 | } // namespace relay |
305 | } // namespace tvm |
306 | |