1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/ir/function.cc
22 * \brief Function in relay.
23 */
24#include <tvm/ir/type_functor.h>
25#include <tvm/relay/analysis.h>
26#include <tvm/relay/expr_functor.h>
27#include <tvm/relay/function.h>
28#include <tvm/relay/transform.h>
29
30namespace tvm {
31namespace relay {
32
33Function::Function(tvm::Array<Var> params, Expr body, Type ret_type,
34 tvm::Array<TypeVar> type_params, DictAttrs attrs, Span span) {
35 ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
36 ICHECK(params.defined());
37 ICHECK(type_params.defined());
38 n->params = std::move(params);
39 n->body = std::move(body);
40 n->ret_type = std::move(ret_type);
41 n->type_params = std::move(type_params);
42 n->attrs = std::move(attrs);
43 n->virtual_device_ = VirtualDevice::FullyUnconstrained();
44 n->span = std::move(span);
45 data_ = std::move(n);
46}
47
48Function WithFields(Function function, Optional<Array<Var>> opt_params, Optional<Expr> opt_body,
49 Optional<Type> opt_ret_type, Optional<Array<TypeVar>> opt_ty_params,
50 Optional<DictAttrs> opt_attrs, Optional<VirtualDevice> opt_virtual_device,
51 Optional<Span> opt_span) {
52 Array<Var> params = opt_params.value_or(function->params);
53 Expr body = opt_body.value_or(function->body);
54 Type ret_type = opt_ret_type.value_or(function->ret_type);
55 Array<TypeVar> ty_params = opt_ty_params.value_or(function->type_params);
56 DictAttrs attrs = opt_attrs.value_or(function->attrs);
57 VirtualDevice virtual_device = opt_virtual_device.value_or(function->virtual_device());
58 Span span = opt_span.value_or(function->span);
59
60 bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) &&
61 attrs.same_as(function->attrs) &&
62 virtual_device.same_as(function->virtual_device()) &&
63 span.same_as(function->span);
64
65 // Check that all the type params are unchanged
66 if (unchanged) {
67 bool all_ty_params_unchanged = true;
68 if (ty_params.size() == function->type_params.size()) {
69 for (size_t i = 0; i < ty_params.size(); i++) {
70 all_ty_params_unchanged &= ty_params[i].same_as(function->type_params[i]);
71 }
72 } else {
73 all_ty_params_unchanged = false;
74 }
75 unchanged &= all_ty_params_unchanged;
76 }
77
78 // Check that all the params are unchanged
79 if (unchanged) {
80 bool all_params_unchanged = true;
81 if (params.size() == function->params.size()) {
82 for (size_t i = 0; i < params.size(); i++) {
83 all_params_unchanged &= params[i].same_as(function->params[i]);
84 }
85 } else {
86 all_params_unchanged = false;
87 }
88 unchanged &= all_params_unchanged;
89 }
90
91 if (!unchanged) {
92 FunctionNode* cow_function_node = function.CopyOnWrite();
93 cow_function_node->params = params;
94 cow_function_node->body = body;
95 cow_function_node->ret_type = ret_type;
96 cow_function_node->type_params = ty_params;
97 cow_function_node->attrs = attrs;
98 cow_function_node->virtual_device_ = virtual_device;
99 cow_function_node->span = span;
100 }
101 return function;
102}
103
104FuncType FunctionNode::func_type_annotation() const {
105 Array<Type> param_types;
106 for (auto param : this->params) {
107 Type param_type =
108 (param->type_annotation.defined()) ? param->type_annotation : IncompleteType(Kind::kType);
109 param_types.push_back(param_type);
110 }
111
112 Type ret_type = (this->ret_type.defined()) ? this->ret_type : IncompleteType(Kind::kType);
113 return FuncType(param_types, ret_type, this->type_params, {});
114}
115
116const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) {
117 if (const auto* function_node = base_func.as<FunctionNode>()) {
118 if (!function_node->GetAttr<String>(attr::kCompiler).defined() &&
119 !function_node->HasNonzeroAttr(attr::kExtern) &&
120 !function_node->HasNonzeroAttr(attr::kSkipOptimization)) {
121 return function_node;
122 }
123 }
124 return nullptr;
125}
126
127TVM_REGISTER_GLOBAL("relay.ir.PrintRelayModule")
128 .set_body_typed([](IRModule mod) -> Optional<String> {
129 for (const auto& it : mod->functions) {
130 if (it.second->IsInstance<FunctionNode>()) {
131 return PrettyPrint(mod);
132 }
133 }
134 return NullOpt;
135 });
136
137TVM_REGISTER_GLOBAL("relay.ir.PrintIR")
138 .set_body_typed([](IRModule mod, String header, bool show_metadata) -> bool {
139 for (const auto& it : mod->functions) {
140 if (it.second->IsInstance<FunctionNode>()) {
141 LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata);
142 return true;
143 }
144 }
145 return false;
146 });
147
148TVM_REGISTER_GLOBAL("relay.ir.WarnIfMalformed")
149 .set_body_typed([](const IRModule& mod, const BaseFunc& base_func) -> void {
150 if (const auto* relay_func = base_func.as<FunctionNode>()) {
151 Function func = Downcast<relay::Function>(relay::DeDup(GetRef<Function>(relay_func)));
152 // Type check the item before we add it to the module.
153 auto fv = relay::FreeVars(func);
154 auto ftv = relay::FreeTypeVars(func, mod);
155 // TODO(@jroesch): refactor to use diagnostic context
156 ICHECK_EQ(fv.size(), 0) << "Function:" << std::endl
157 << PrettyPrint(func) << std::endl
158 << "contains free variables: " << fv;
159 ICHECK_EQ(ftv.size(), 0) << "Function:" << std::endl
160 << PrettyPrint(func) << std::endl
161 << "contains free type variables: " << fv;
162 }
163 });
164TVM_REGISTER_GLOBAL("relay.ir.IRModuleAdd")
165 .set_body_typed([](IRModule mod, GlobalVar var, ObjectRef val, bool update) -> IRModule {
166 if (val->IsInstance<BaseFuncNode>()) {
167 mod->Add(var, Downcast<BaseFunc>(val), update);
168 } else if (val->IsInstance<GlobalVarNode>()) {
169 GlobalVar gv = Downcast<GlobalVar>(val);
170 IRModule mod_copy(make_object<IRModuleNode>(*mod.operator->()));
171 mod_copy = relay::transform::EtaExpand(
172 /* expand_constructor */ false,
173 /* expand_global_var */ true)(mod_copy);
174 auto func = mod_copy->Lookup(gv->name_hint);
175 mod->Add(var, Downcast<relay::Function>(func), update);
176 } else {
177 auto func = relay::Function({}, Downcast<RelayExpr>(val), Type(nullptr), {});
178 mod->Add(var, func, update);
179 }
180 return mod;
181 });
182
183TVM_REGISTER_GLOBAL("relay.ir.IRModuleUpdateWithRenamer")
184 .set_body_typed([](IRModule self, IRModule mod) -> void {
185 struct Renamer : relay::ExprMutator, TypeMutator {
186 Map<String, GlobalVar> defs;
187 Map<String, GlobalTypeVar> types;
188 std::unordered_map<int32_t, Constructor> ctors;
189
190 Renamer(Map<String, GlobalVar> defs_one, Map<String, GlobalVar> defs_two,
191 Map<String, GlobalTypeVar> types_one, Map<String, GlobalTypeVar> types_two,
192 std::unordered_map<int32_t, Constructor> ctors_one,
193 std::unordered_map<int32_t, Constructor> ctor_two) {
194 for (auto pair : defs_one) {
195 defs.Set(pair.first, pair.second);
196 }
197
198 for (auto pair : defs_two) {
199 auto it = defs.find(pair.first);
200 if (it == defs.end()) {
201 defs.Set(pair.first, pair.second);
202 }
203 }
204
205 for (auto pair : types_one) {
206 types.Set(pair.first, pair.second);
207 }
208
209 for (auto pair : types_two) {
210 auto it = types.find(pair.first);
211 if (it == types.end()) {
212 types.Set(pair.first, pair.second);
213 }
214 }
215 }
216
217 relay::Expr VisitExpr_(const GlobalVarNode* node) override {
218 return defs.at(node->name_hint);
219 }
220
221 Type VisitType_(const GlobalTypeVarNode* node) override {
222 return types.at(node->name_hint);
223 }
224 };
225
226 Renamer renamer(self->global_var_map_, mod->global_var_map_, self->global_type_var_map_,
227 mod->global_type_var_map_, self->constructor_tag_map_,
228 mod->constructor_tag_map_);
229
230 self->global_var_map_ = renamer.defs;
231 self->global_type_var_map_ = renamer.types;
232 self->constructor_tag_map_ = renamer.ctors;
233
234 for (auto pair : mod->type_definitions) {
235 auto tvar = renamer.types.at(pair.first->name_hint);
236 auto ty = renamer.ExprMutator::VisitType(pair.second);
237 self->AddTypeDefUnchecked(tvar, Downcast<TypeData>(ty), true);
238 }
239
240 for (auto pair : mod->functions) {
241 if (auto rfn = pair.second.as<relay::FunctionNode>()) {
242 auto gvar = renamer.defs.at(pair.first->name_hint);
243 auto fn = renamer.VisitExpr(GetRef<relay::Function>(rfn));
244 self->AddUnchecked(gvar, Downcast<BaseFunc>(fn));
245 } else {
246 // TODO(@jroesch): rename into IRModule.
247 self->AddUnchecked(pair.first, pair.second);
248 }
249 }
250 });
251
252TVM_REGISTER_GLOBAL("relay.ir.FunctionFromExprInContext")
253 .set_body_typed([](RelayExpr expr, IRModule mod) -> Function {
254 return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
255 });
256
257TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr")
258 .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> Optional<Function> {
259 if (func->IsInstance<relay::FunctionNode>()) {
260 return WithAttr(Downcast<relay::Function>(std::move(func)), key, value);
261 }
262 return NullOpt;
263 });
264
265TVM_REGISTER_GLOBAL("relay.ir.FuncWithoutAttr")
266 .set_body_typed([](BaseFunc func, String key) -> Optional<Function> {
267 if (func->IsInstance<relay::FunctionNode>()) {
268 return WithoutAttr(Downcast<relay::Function>(std::move(func)), key);
269 }
270 return NullOpt;
271 });
272
// Register FunctionNode as a TVM node type.
TVM_REGISTER_NODE_TYPE(FunctionNode);
274
275TVM_REGISTER_GLOBAL("relay.ir.Function")
276 .set_body_typed([](tvm::Array<Var> params, Expr body, Type ret_type,
277 tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs, Span span) {
278 return Function(params, body, ret_type, ty_params, attrs, span);
279 });
280TVM_REGISTER_GLOBAL("relay.ir.FunctionWithFields")
281 .set_body_typed([](Function function, Optional<Array<Var>> opt_params, Optional<Expr> opt_body,
282 Optional<Type> opt_ret_type, Optional<Array<TypeVar>> opt_ty_params,
283 Optional<DictAttrs> opt_attrs, Optional<VirtualDevice> opt_virtual_device,
284 Optional<Span> opt_span) {
285 return WithFields(function, opt_params, opt_body, opt_ret_type, opt_ty_params, opt_attrs,
286 opt_virtual_device, opt_span);
287 });
288
289TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
290 .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
291 // TODO(@jroesch): previously this had a debug printer, the debug printer
292 // can cause exponential behavior and is currently dangerous, for these
293 // cases we need some kind of de-duping.
294 //
295 // See old implementation:
296 //
297 // auto* node = static_cast<const FunctionNode*>(ref.get());
298 // p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " <<
299 // node->body
300 // << ", " << node->type_params << ", " << node->attrs << ")";
301 p->stream << PrettyPrint(ref);
302 });
303
304} // namespace relay
305} // namespace tvm
306