1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file type_infer.cc |
22 | * \brief Relay type inference and checking. |
23 | * |
24 | * This file implements one of the most important passes to the |
25 | * Relay IR. In order to do many transformations and generate the |
26 | * most efficient code we need to obtain type information for the |
27 | * IR. |
28 | * |
29 | * Similar to previous computation graph based IRs, the Relay IR leaves |
30 | * type information implicit and computes types by performing program |
31 | * analysis. |
32 | * |
33 | * Given an expression `e` this pass infers a type `t` for |
34 | * the expression as well as simultaneously checking the property `e : t` |
35 | * (i.e., we can show e has type t). |
36 | * |
37 | * If we can not infer a type or there is a conflicting |
38 | * constraint it will emit errors. |
39 | */ |
40 | |
41 | #include <tvm/ir/transform.h> |
42 | #include <tvm/ir/type_functor.h> |
43 | #include <tvm/relay/analysis.h> |
44 | #include <tvm/relay/dataflow_matcher.h> |
45 | #include <tvm/relay/expr_functor.h> |
46 | #include <tvm/relay/pattern_functor.h> |
47 | #include <tvm/relay/transform.h> |
48 | |
49 | #include "../analysis/type_solver.h" |
50 | #include "pass_utils.h" |
51 | |
52 | namespace tvm { |
53 | namespace relay { |
54 | |
// Necessary deferred relation for TupleGetItem
// Attributes handed to the deferred TupleGetItem type relation (TupleGetItemRel):
// the position of the tuple field being projected.
struct TupleGetItemAttrs : public tvm::AttrsNode<TupleGetItemAttrs> {
  /*! \brief Index of the tuple field to extract; TupleGetItemRel checks it is in range. */
  int index;

  TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { TVM_ATTR_FIELD(index); }
};
61 | |
62 | bool TupleGetItemRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
63 | const TypeReporter& reporter) { |
64 | ICHECK_EQ(types.size(), 2); |
65 | if (types[0].as<IncompleteTypeNode>()) return false; |
66 | const auto* data = types[0].as<TupleTypeNode>(); |
67 | ICHECK(data != nullptr) << "TupleGetItem expect input type to be TupleType " |
68 | << " get " << types[0] << " instead" ; |
69 | const auto* param = attrs.as<TupleGetItemAttrs>(); |
70 | ICHECK(param != nullptr); |
71 | ICHECK_GE(param->index, 0); |
72 | ICHECK_LT(param->index, data->fields.size()); |
73 | reporter->Assign(types[1], data->fields[param->index]); |
74 | return true; |
75 | } |
76 | |
// Register the attrs node type and expose TupleGetItemRel under a global name;
// TypeInferencer looks it up lazily via EnvFunc::Get with this same name.
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem").set_body_typed(TupleGetItemRel);
79 | |
80 | struct ResolvedTypeInfo { |
81 | explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args) |
82 | : checked_type(checked_type), type_args(type_args) {} |
83 | ResolvedTypeInfo() {} |
84 | |
85 | Type checked_type; |
86 | // Only allocated when the expression is a call. |
87 | |
88 | Array<Type> type_args = Array<Type>(ObjectPtr<Object>(nullptr)); |
89 | }; |
90 | |
91 | // |
92 | // The inference algorithm can roughly be divided into three stages: |
93 | // - Populate the constraints by visiting the expression (TypeInferencer.GetType) |
94 | // - solver.AddConstraint and solver.Unify are called to populate the necessary constraints |
95 | // - Solve the constraints (solver_.Solve) |
96 | // - Recreate expression with the resolved checked_type (Resolver.VisitExpr) |
97 | // |
98 | class TypeInferencer : private ExprFunctor<Type(const Expr&)>, |
99 | private PatternFunctor<void(const Pattern&, const Type&)> { |
100 | public: |
101 | // constructors |
102 | |
103 | explicit TypeInferencer(IRModule mod, DiagnosticContext diag_ctx) |
104 | : mod_(mod), diag_ctx(diag_ctx), solver_(GlobalVar(), diag_ctx) { |
105 | ICHECK(mod.defined()) << "Module must not be null in the type inferencer." ; |
106 | } |
107 | |
108 | // Infer the types inside of a function. |
109 | Expr Infer(GlobalVar var, Function expr); |
110 | |
111 | private: |
112 | // type resolver that maps back to type |
113 | class Resolver; |
114 | // internal environment |
115 | IRModule mod_; |
116 | |
117 | // The current function being type checked. |
118 | GlobalVar current_func_; |
119 | |
120 | /*! \brief The diagnostic context. */ |
121 | DiagnosticContext diag_ctx; |
122 | |
123 | // map from expression to checked type |
124 | // type inferencer will populate it up |
125 | std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual> type_map_; |
126 | |
127 | // The solver used by the inferencer. |
128 | TypeSolver solver_; |
129 | // relation function |
130 | TypeRelationFn tuple_getitem_rel_; |
131 | TypeRelationFn make_tuple_rel_; |
132 | |
133 | /*! \brief Internal map used for memoization. */ |
134 | std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> memo_; |
135 | |
136 | void VisitLeaf(const Expr& expr) { |
137 | if (!memo_.count(expr)) { |
138 | Type ret = this->DispatchVisitExpr(expr); |
139 | memo_[expr] = ret; |
140 | } |
141 | } |
142 | |
143 | bool CheckVisited(const Expr& expr) { |
144 | if (memo_.count(expr)) { |
145 | return true; |
146 | } else { |
147 | return false; |
148 | } |
149 | } |
150 | |
151 | Type DispatchVisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } |
152 | |
153 | Type VisitExpr(const Expr& expr) final { |
154 | auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; |
155 | auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); }; |
156 | if (memo_.count(expr)) { |
157 | return memo_[expr]; |
158 | } else { |
159 | ExpandDataflow(expr, fcheck_visited, fvisit_leaf); |
160 | return memo_[expr]; |
161 | } |
162 | } |
163 | |
164 | // Perform unification on two types and report the error at the expression |
165 | // or the span of the expression. |
166 | Type Unify(const Type& t1, const Type& t2, const Span& span, bool assign_lhs = true, |
167 | bool assign_rhs = true) { |
168 | try { |
169 | return solver_.Unify(t1, t2, span, assign_lhs, assign_rhs); |
170 | } catch (const Error& e) { |
171 | this->EmitFatal(Diagnostic::Error(span) |
172 | << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what()); |
173 | return Type(); |
174 | } |
175 | } |
176 | |
177 | // Lazily get type for expr |
178 | // expression, we will populate it now, and return the result. |
179 | Type GetType(const Expr& expr) { |
180 | auto it = type_map_.find(expr); |
181 | if (it != type_map_.end() && it->second.checked_type.defined()) { |
182 | return it->second.checked_type; |
183 | } |
184 | Type ret = this->VisitExpr(expr); |
185 | ICHECK(ret.defined()) << "expression:" << std::endl << PrettyPrint(expr); |
186 | KindCheck(ret, mod_, this->diag_ctx); |
187 | ResolvedTypeInfo& rti = type_map_[expr]; |
188 | rti.checked_type = ret; |
189 | return ret; |
190 | } |
191 | |
192 | void EmitFatal(const Diagnostic& diag) { this->diag_ctx.EmitFatal(diag); } |
193 | |
194 | // Visitor Logic |
195 | Type VisitExpr_(const VarNode* op) final { |
196 | if (op->type_annotation.defined()) { |
197 | return op->type_annotation; |
198 | } else { |
199 | return IncompleteType(Kind::kType); |
200 | } |
201 | } |
202 | |
203 | Type VisitExpr_(const GlobalVarNode* op) final { |
204 | GlobalVar var = GetRef<GlobalVar>(op); |
205 | if (!mod_.defined()) { |
206 | this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables " |
207 | << "without a module" ); |
208 | } |
209 | if (mod_->ContainGlobalVar(var->name_hint)) { |
210 | BaseFunc func = mod_->Lookup(var->name_hint); |
211 | |
212 | if (const auto* function_node = func.as<FunctionNode>()) { |
213 | VLOG(1) << "global var '" << op->name_hint << "' bound to Function" ; |
214 | return function_node->checked_type(); |
215 | } else { |
216 | VLOG(1) << "global var '" << op->name_hint << "' bound to PrimFunc" ; |
217 | return op->checked_type_; |
218 | } |
219 | } else { |
220 | // TODO(mbs): extern function cleanup |
221 | // Assume the function is extern thus no longer in the IRModule. |
222 | VLOG(1) << "global var '" << op->name_hint << "' not in module" ; |
223 | return op->checked_type_; |
224 | } |
225 | } |
226 | |
227 | Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); } |
228 | |
229 | Type VisitExpr_(const TupleNode* op) final { |
230 | Array<Type> types; |
231 | for (Expr field : op->fields) { |
232 | types.push_back(GetType(field)); |
233 | } |
234 | return TupleType(types); |
235 | } |
236 | |
237 | Type VisitExpr_(const TupleGetItemNode* op) final { |
238 | if (!tuple_getitem_rel_.defined()) { |
239 | tuple_getitem_rel_ = |
240 | Downcast<TypeRelationFn>(EnvFunc::Get("tvm.relay.type_relation.TupleGetItem" )); |
241 | } |
242 | Type tuple_type = GetType(op->tuple); |
243 | Type rtype = IncompleteType(Kind::kType); |
244 | auto attrs = make_object<TupleGetItemAttrs>(); |
245 | attrs->index = op->index; |
246 | solver_.AddConstraint(TypeRelation(tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), |
247 | op->span); |
248 | return rtype; |
249 | } |
250 | |
251 | void VisitPattern_(const PatternConstructorNode* con, const Type& t) { |
252 | ICHECK(mod_.defined()) << "Cannot do type inference without a environment:" |
253 | << con->constructor->name_hint; |
254 | TypeData td = mod_->type_definitions.at(con->constructor->belong_to); |
255 | auto pc = GetRef<PatternConstructor>(con); |
256 | |
257 | // we can expect a certain number of arguments |
258 | Array<Type> unknown_args; |
259 | for (size_t i = 0; i < td->type_vars.size(); i++) { |
260 | unknown_args.push_back(IncompleteType(Kind::kType)); |
261 | } |
262 | |
263 | Type expected = TypeCall(con->constructor->belong_to, unknown_args); |
264 | Type unified = Unify(t, expected, pc->span); |
265 | |
266 | auto* tc = unified.as<TypeCallNode>(); |
267 | if (!tc) { |
268 | this->EmitFatal(Diagnostic::Error(pc->span) << "Expected a type call, got " << unified); |
269 | } |
270 | |
271 | if (td->header != tc->func) { |
272 | this->EmitFatal(Diagnostic::Error(pc->span) << "ADT headers must match, but we have " |
273 | << td->header << " and " << tc->func); |
274 | } |
275 | |
276 | if (td->type_vars.size() != tc->args.size()) { |
277 | this->EmitFatal(Diagnostic::Error(pc->span) |
278 | << "The number of type args must match" |
279 | << "the number of type vars in the type data: " << td->type_vars.size() |
280 | << " != " << tc->args.size()); |
281 | } |
282 | std::unordered_map<TypeVar, Type, ObjectPtrHash, ObjectPtrEqual> type_var_map_; |
283 | for (size_t i = 0; i < td->type_vars.size(); ++i) { |
284 | type_var_map_[td->type_vars[i]] = tc->args[i]; |
285 | } |
286 | |
287 | if (con->constructor->inputs.size() != con->patterns.size()) { |
288 | this->EmitFatal(Diagnostic::Error(pc->span) << "Not enough inputs for the constructor; " |
289 | << "expected " << con->constructor->inputs.size() |
290 | << ", got " << con->patterns.size()); |
291 | } |
292 | |
293 | for (size_t i = 0; i < con->constructor->inputs.size(); ++i) { |
294 | VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_)); |
295 | } |
296 | } |
297 | |
298 | void VisitPattern_(const PatternTupleNode* tup, const Type& t) { |
299 | auto pt = GetRef<PatternTuple>(tup); |
300 | |
301 | // we can expect a certain number of arguments |
302 | Array<Type> unknown_args; |
303 | for (size_t i = 0; i < tup->patterns.size(); i++) { |
304 | unknown_args.push_back(IncompleteType(Kind::kType)); |
305 | } |
306 | |
307 | Type expected = TupleType(unknown_args); |
308 | Type unified = Unify(t, expected, tup->span); |
309 | |
310 | auto* tt = unified.as<TupleTypeNode>(); |
311 | if (!tt) { |
312 | this->EmitFatal(Diagnostic::Error(pt->span) << "Expected a tuple type, got " << unified); |
313 | } |
314 | ICHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern" ; |
315 | for (size_t i = 0; i < tup->patterns.size(); ++i) { |
316 | VisitPattern(tup->patterns[i], tt->fields[i]); |
317 | } |
318 | } |
319 | |
320 | void VisitPattern_(const PatternVarNode* pv, const Type& t) { |
321 | Type vt = GetType(pv->var); |
322 | Unify(vt, t, pv->span); |
323 | } |
324 | |
325 | void VisitPattern_(const PatternWildcardNode* wc, const Type& t) {} |
326 | |
327 | Type VisitExpr_(const MatchNode* op) final { |
328 | Type dtype = GetType(op->data); |
329 | for (const auto& c : op->clauses) { |
330 | VisitPattern(c->lhs, dtype); |
331 | } |
332 | Type rtype = IncompleteType(Kind::kType); |
333 | for (const auto& c : op->clauses) { |
334 | rtype = this->Unify(rtype, GetType(c->rhs), op->span); |
335 | } |
336 | |
337 | if (op->complete) { |
338 | // check completness |
339 | Match match = GetRef<Match>(op); |
340 | Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_); |
341 | if (unmatched_cases.size() != 0) { |
342 | ErrorBuilder ss; |
343 | auto err = Diagnostic::Error(match->span); |
344 | err << "match expression does not handle the following cases: " ; |
345 | int i = 0; |
346 | for (auto cs : unmatched_cases) { |
347 | err << "case " << i++ << ": \n" << PrettyPrint(cs); |
348 | } |
349 | this->EmitFatal(err); |
350 | } |
351 | } |
352 | |
353 | return rtype; |
354 | } |
355 | |
356 | Type VisitExpr_(const OpNode* op) final { return op->op_type; } |
357 | |
358 | Type VisitExpr_(const LetNode* let) final { |
359 | auto pre_visit = [this](const LetNode* op) { |
360 | // if the definition is a function literal, permit recursion |
361 | bool is_functional_literal = op->value.as<FunctionNode>() != nullptr; |
362 | Type let_type = IncompleteType(Kind::kType); |
363 | |
364 | if (is_functional_literal) { |
365 | let_type = this->GetType(op->var); |
366 | this->type_map_[op->var].checked_type = let_type; |
367 | } |
368 | |
369 | if (op->var->type_annotation.defined()) { |
370 | let_type = this->Unify(let_type, op->var->type_annotation, op->span); |
371 | } |
372 | |
373 | Type vtype = this->GetType(op->value); |
374 | let_type = this->Unify(let_type, vtype, op->span); |
375 | |
376 | ICHECK(is_functional_literal || !this->type_map_.count(op->var)); |
377 | // NOTE: no scoping is necessary because var are unique in program |
378 | this->type_map_[op->var].checked_type = let_type; |
379 | }; |
380 | auto post_visit = [this](const LetNode* op) { |
381 | Expr expr = GetRef<Expr>(op); |
382 | this->memo_[expr] = this->GetType(op->body); |
383 | this->type_map_[expr].checked_type = this->memo_[expr]; |
384 | }; |
385 | ExpandANormalForm(let, pre_visit, post_visit); |
386 | return memo_[GetRef<Expr>(let)]; |
387 | } |
388 | |
389 | Type VisitExpr_(const IfNode* ite) final { |
390 | // Ensure the type of the guard is of Tensor[Bool, ()], |
391 | // that is a rank-0 boolean tensor. |
392 | Type cond_type = this->GetType(ite->cond); |
393 | this->Unify(cond_type, TensorType::Scalar(tvm::DataType::Bool()), ite->cond->span); |
394 | Type checked_true = this->GetType(ite->true_branch); |
395 | Type checked_false = this->GetType(ite->false_branch); |
396 | return this->Unify(checked_true, checked_false, ite->span); |
397 | } |
398 | |
399 | // This code is special-cased for primitive operators, |
400 | // which are registered in the style defined in src/relay/op/*. |
401 | // |
402 | // The result will be the return type of the operator. |
403 | Type PrimitiveCall(const FuncTypeNode* op, Array<Type> arg_types, const Attrs& attrs, |
404 | const Span& span) { |
405 | if (op->type_params.size() != arg_types.size() + 1) return Type(); |
406 | if (op->type_constraints.size() != 1) return Type(); |
407 | const TypeRelationNode* rel = op->type_constraints[0].as<TypeRelationNode>(); |
408 | if (rel == nullptr) return Type(); |
409 | // validate if the type parameter matches up |
410 | for (size_t i = 0; i < op->type_params.size(); ++i) { |
411 | if (!op->type_params[i].same_as(rel->args[i])) return Type(); |
412 | } |
413 | Type rtype = IncompleteType(Kind::kType); |
414 | arg_types.push_back(rtype); |
415 | // we can do simple replacement here |
416 | solver_.AddConstraint(TypeRelation(rel->func, arg_types, arg_types.size() - 1, attrs), span); |
417 | return rtype; |
418 | } |
419 | |
420 | // substitute the type args in the function type |
421 | FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array<Type>& ty_args) { |
422 | tvm::Map<TypeVar, Type> subst_map; |
423 | |
424 | // Build a subsitituion map up from the function type and type arguments. |
425 | // Eventually allow the type vars to be passed in. |
426 | ICHECK(fn_ty->type_params.size() == ty_args.size()) |
427 | << "number of type parameters does not match expected" ; |
428 | for (size_t i = 0; i < ty_args.size(); ++i) { |
429 | subst_map.Set(fn_ty->type_params[i], ty_args[i]); |
430 | } |
431 | |
432 | Type ret_type = fn_ty->ret_type; |
433 | |
434 | // If the function type is incomplete, place a new IncompleteType |
435 | // This relax the fn_ty to inputs -> Any |
436 | // The type checking can still pass when there are additional constraints on the type |
437 | // This is a temporary work around to check recursive functions whose |
438 | // return type is not yet known. |
439 | if (!ret_type.defined()) { |
440 | ret_type = IncompleteType(Kind::kType); |
441 | } |
442 | |
443 | Type inst_ty = FuncType(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints); |
444 | inst_ty = Bind(inst_ty, subst_map); |
445 | return Downcast<FuncType>(inst_ty); |
446 | } |
447 | |
448 | // instantiates starting from incompletes |
449 | FuncType InstantiateFuncType(const FuncTypeNode* fn_ty) { |
450 | if (fn_ty->type_params.size() == 0) { |
451 | return GetRef<FuncType>(fn_ty); |
452 | } |
453 | |
454 | Array<Type> type_args; |
455 | for (size_t i = 0; i < fn_ty->type_params.size(); i++) { |
456 | type_args.push_back(IncompleteType(Kind::kType)); |
457 | } |
458 | return InstantiateFuncType(fn_ty, type_args); |
459 | } |
460 | |
461 | void AddTypeArgs(const Expr& expr, Array<Type> type_args) { |
462 | auto type_info = type_map_.find(expr); |
463 | if (type_info == type_map_.end()) { |
464 | type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)}); |
465 | } else { |
466 | ICHECK(!type_info->second.type_args.defined()); |
467 | type_info->second.type_args = type_args; |
468 | } |
469 | } |
470 | |
471 | // Handle general call node. |
472 | Type GeneralCall(const CallNode* call, Array<Type> arg_types) { |
473 | Type ftype = GetType(call->op); |
474 | auto* fn_ty_node = ftype.as<FuncTypeNode>(); |
475 | auto* inc_ty_node = ftype.as<IncompleteTypeNode>(); |
476 | |
477 | if (fn_ty_node == nullptr && inc_ty_node == nullptr) { |
478 | this->EmitFatal(Diagnostic::Error(call->span) |
479 | << "only expressions with function types can be called, found " << ftype); |
480 | } |
481 | |
482 | // incomplete type => it must be a function taking the arg types |
483 | // with an unknown return type |
484 | if (inc_ty_node != nullptr) { |
485 | Type ret_type = IncompleteType(Kind::kType); |
486 | Type func_type = FuncType(arg_types, ret_type, {}, {}); |
487 | Type unified = this->Unify(ftype, func_type, call->op->span); |
488 | fn_ty_node = unified.as<FuncTypeNode>(); |
489 | } |
490 | |
491 | Array<Type> type_args = call->type_args; |
492 | if (type_args.size() > fn_ty_node->type_params.size()) { |
493 | this->EmitFatal(Diagnostic::Error(call->span) |
494 | << "Incorrect number of type args in " << call->span << ": " |
495 | << "Expected " << fn_ty_node->type_params.size() << " but got " |
496 | << type_args.size() << " for call:\n" |
497 | << PrettyPrint(GetRef<Call>(call))); |
498 | } |
499 | for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) { |
500 | type_args.push_back(IncompleteType(TypeKind::kType)); |
501 | } |
502 | |
503 | FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); |
504 | |
505 | AddTypeArgs(GetRef<Call>(call), type_args); |
506 | |
507 | size_t type_arity = fn_ty->arg_types.size(); |
508 | size_t number_of_args = arg_types.size(); |
509 | bool is_variable = false; |
510 | |
511 | if (const OpNode* opnode = call->op.as<OpNode>()) { |
512 | if (opnode->num_inputs == -1) { |
513 | is_variable = true; |
514 | } |
515 | } |
516 | |
517 | if ((type_arity < number_of_args) && !is_variable) { |
518 | this->EmitFatal(Diagnostic::Error(call->span) |
519 | << "the function is provided too many arguments " |
520 | << "expected " << type_arity << ", found " << number_of_args); |
521 | } else if (type_arity > number_of_args) { |
522 | this->EmitFatal(Diagnostic::Error(call->span) |
523 | << "the function is provided too few arguments " |
524 | << "expected " << type_arity << ", found " << number_of_args); |
525 | } |
526 | |
527 | Array<Type> unified_arg_types; |
528 | if (!is_variable) { |
529 | for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { |
530 | this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false); |
531 | } |
532 | } else { |
533 | for (size_t i = 0; i < number_of_args; i++) { |
534 | if (i < fn_ty->arg_types.size()) { |
535 | unified_arg_types.push_back( |
536 | this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, false, false)); |
537 | } else { |
538 | unified_arg_types.push_back(arg_types[i]); |
539 | } |
540 | } |
541 | unified_arg_types.push_back(fn_ty->ret_type); |
542 | } |
543 | for (auto cs : fn_ty->type_constraints) { |
544 | if (const auto* tr = cs.as<TypeRelationNode>()) { |
545 | if (!is_variable) { |
546 | solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), |
547 | call->span); |
548 | } else { |
549 | solver_.AddConstraint( |
550 | TypeRelation(tr->func, unified_arg_types, number_of_args, call->attrs), call->span); |
551 | } |
552 | } else { |
553 | solver_.AddConstraint(cs, call->span); |
554 | } |
555 | } |
556 | |
557 | return fn_ty->ret_type; |
558 | } |
559 | |
560 | Type VisitExpr_(const CallNode* call) final { |
561 | Array<Type> arg_types; |
562 | for (Expr arg : call->args) { |
563 | arg_types.push_back(GetType(arg)); |
564 | } |
565 | |
566 | if (const OpNode* opnode = call->op.as<OpNode>()) { |
567 | Type rtype = |
568 | PrimitiveCall(opnode->op_type.as<FuncTypeNode>(), arg_types, call->attrs, call->span); |
569 | |
570 | if (rtype.defined()) { |
571 | AddTypeArgs(GetRef<Call>(call), arg_types); |
572 | return rtype; |
573 | } |
574 | } |
575 | |
576 | solver_.Solve(); |
577 | return GeneralCall(call, arg_types); |
578 | } |
579 | |
580 | Type VisitExpr_(const FunctionNode* f) final { |
581 | solver_.Solve(); |
582 | Array<Type> arg_types; |
583 | for (auto param : f->params) { |
584 | arg_types.push_back(GetType(param)); |
585 | } |
586 | Type rtype = GetType(f->body); |
587 | if (auto* ft = rtype.as<FuncTypeNode>()) { |
588 | rtype = InstantiateFuncType(ft); |
589 | } |
590 | if (f->ret_type.defined()) { |
591 | rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f)->span); |
592 | } |
593 | ICHECK(rtype.defined()); |
594 | auto ret = FuncType(arg_types, rtype, f->type_params, {}); |
595 | return solver_.Resolve(ret); |
596 | } |
597 | |
598 | Type VisitExpr_(const RefCreateNode* op) final { return RelayRefType(GetType(op->value)); } |
599 | |
600 | Type VisitExpr_(const RefReadNode* op) final { |
601 | Type it = IncompleteType(Kind::kType); |
602 | this->Unify(GetType(op->ref), RelayRefType(it), op->span); |
603 | return it; |
604 | } |
605 | |
606 | Type VisitExpr_(const RefWriteNode* op) final { |
607 | Type it = IncompleteType(Kind::kType); |
608 | this->Unify(GetType(op->ref), RelayRefType(it), op->span); |
609 | this->Unify(GetType(op->value), it, op->span); |
610 | return TupleType::Empty(); |
611 | } |
612 | |
613 | Type VisitExpr_(const ConstructorNode* c) final { |
614 | ICHECK(mod_.defined()) << "Cannot do type inference without a environment:" << c->name_hint; |
615 | TypeData td = mod_->LookupTypeDef(c->belong_to); |
616 | std::vector<Type> types; |
617 | for (const auto& t : td->type_vars) { |
618 | types.push_back(t); |
619 | } |
620 | return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {}); |
621 | } |
622 | |
623 | void Solve() { solver_.Solve(); } |
624 | }; |
625 | |
// Mutator run after constraint solving. It rebuilds the expression, attaching to each node
// the solver-resolved version of the type TypeInferencer recorded in tmap_, resolving the
// recorded type_args of calls, and (while update_missing_type_annotation_ is true) filling
// in missing Var type annotations and Function return types.
class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
 public:
  /*!
   * \param tmap Types recorded by TypeInferencer. Held by reference, so it must outlive
   *        this resolver.
   * \param solver Solver used to resolve recorded types (not owned).
   */
  Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap,
           TypeSolver* solver)
      : tmap_(tmap), solver_(solver) {}

  using MixedModeMutator::VisitExpr_;

  Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }

  Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); }

  // Global vars are returned unchanged; no type is attached here.
  Expr VisitExpr_(const GlobalVarNode* op) final { return GetRef<GlobalVar>(op); }

  // Operators take the default ExprMutator path; no type is attached here.
  Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); }

  // Tuple, TupleGetItem and Call go through MixedModeMutator's Rewrite_ hooks; `post` is the
  // already-mutated node (presumably with rewritten children) that receives the type.
  Expr Rewrite_(const TupleNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
    return AttachCheckedType(op, post);
  }

  Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); }

  Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

  // Let chains are walked with ExpandANormalForm: pre_visit visits each binding's var and
  // value, post_visit rebuilds each Let from the (memoized) rewritten parts, attaches its
  // type, and stores the result in memo_.
  Expr VisitExpr_(const LetNode* op) final {
    auto pre_visit = [this](const LetNode* op) {
      this->VisitExpr(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      Expr expr = GetRef<Expr>(op);
      Var var = Downcast<Var>(this->VisitExpr(op->var));
      Expr value = this->VisitExpr(op->value);
      Expr body = this->VisitExpr(op->body);
      this->memo_[expr] = this->AttachCheckedType(op, Let(var, value, body));
    };
    ExpandANormalForm(op, pre_visit, post_visit);
    return memo_[GetRef<Expr>(op)];
  }

  Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const RefCreateNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const RefReadNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const RefWriteNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const ConstructorNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const MatchNode* op) final { return AttachCheckedType(op); }

  Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }

  // Each Var is rewritten once and cached in vmap_, so every use maps to the same new Var.
  Var VisitVar(const Var& v) final {
    if (vmap_.count(v) == 0) {
      vmap_[v] = GetRef<Var>(AttachCheckedType(v.as<VarNode>()).as<VarNode>());
    }
    return vmap_.at(v);
  }

  // attach checked type to the mutated node.
  // `op` must have an entry in tmap_. If `post` is undefined, the node is first mutated via
  // ExprMutator::VisitExpr_. A still-incomplete resolved type is reported through
  // solver_->Emit (a non-fatal Emit, not EmitFatal) and processing continues.
  template <typename T>
  Expr AttachCheckedType(const T* op, const Expr& post = Expr()) {
    auto it = tmap_.find(GetRef<Expr>(op));
    ICHECK(it != tmap_.end());
    Type checked_type = solver_->Resolve(it->second.checked_type);

    if (checked_type.as<IncompleteTypeNode>() != nullptr) {
      this->solver_->Emit(
          Diagnostic::Error(op->span)
          << "The type inference pass was unable to infer a type for this expression.\n"
          << "This usually occurs when an operator call is under constrained in some way,"
          << " check other reported errors for hints of what may of happened.");
    }

    Expr new_e = post.defined() ? post : ExprMutator::VisitExpr_(op);
    // new_call, new_var and new_fn are only meaningful when T is CallNode, VarNode or
    // FunctionNode respectively; otherwise they are nullptr and never dereferenced, since
    // every use below is guarded by the matching std::is_base_of test.
    // Compiler optimization will likely fold these away for other nodes.
    CallNode* new_call = (std::is_base_of<CallNode, T>::value
                              ? const_cast<CallNode*>(static_cast<const CallNode*>(new_e.get()))
                              : nullptr);
    VarNode* new_var = (std::is_base_of<VarNode, T>::value
                            ? const_cast<VarNode*>(static_cast<const VarNode*>(new_e.get()))
                            : nullptr);
    FunctionNode* new_fn =
        (std::is_base_of<FunctionNode, T>::value
             ? const_cast<FunctionNode*>(static_cast<const FunctionNode*>(new_e.get()))
             : nullptr);

    // check if we need update the new_e
    bool need_update_type = !checked_type.same_as(new_e->checked_type_);
    bool need_update_call =
        (std::is_base_of<CallNode, T>::value && it->second.type_args.defined() &&
         !it->second.type_args.same_as(new_call->type_args));
    bool need_update_var = (std::is_base_of<VarNode, T>::value && update_missing_type_annotation_ &&
                            !new_var->type_annotation.defined());

    bool need_update_fn = (std::is_base_of<FunctionNode, T>::value &&
                           update_missing_type_annotation_ && !new_fn->ret_type.defined());

    // Nothing to change: return the node as-is.
    if (!need_update_type && !need_update_var && !need_update_call && !need_update_fn) {
      return new_e;
    }

    if (!new_e.unique()) {
      // Copy on write optimization
      // If new_e is shared (e.g. an old expression), mutate a fresh copy instead of
      // mutating an existing reference.
      ObjectPtr<ExprNode> ptr = make_object<T>(*new_e.as<T>());
      new_e = Expr(ptr);
      new_call =
          (std::is_base_of<CallNode, T>::value ? static_cast<CallNode*>(ptr.get()) : nullptr);
      new_var = (std::is_base_of<VarNode, T>::value ? static_cast<VarNode*>(ptr.get()) : nullptr);
      new_fn = (std::is_base_of<FunctionNode, T>::value ? static_cast<FunctionNode*>(ptr.get())
                                                        : nullptr);
    }

    // attach the information.
    if (need_update_type) {
      new_e->checked_type_ = checked_type;
    }

    if (need_update_call) {
      // Copy the recorded type args, then resolve each one through the solver.
      new_call->type_args = it->second.type_args;
      for (size_t i = 0; i < new_call->type_args.size(); i++) {
        new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i]));
      }
    }
    if (need_update_var) {
      new_var->type_annotation = checked_type;
    }
    if (need_update_fn) {
      // A function's resolved checked type is expected to be a FuncType.
      auto* fn_type = checked_type.as<FuncTypeNode>();
      ICHECK(fn_type != nullptr);
      new_fn->ret_type = fn_type->ret_type;
    }
    return new_e;
  }

  // Every type the mutator visits is replaced by its solver-resolved form.
  Type VisitType(const Type& t) final { return solver_->Resolve(t); }

 private:
  // Cache of rewritten Vars (see VisitVar).
  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> vmap_;
  const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap_;
  TypeSolver* solver_;
  // whether attach the checked type as type_annotation
  // if original type anntation is missing.
  bool update_missing_type_annotation_{true};
};
778 | |
779 | Expr TypeInferencer::Infer(GlobalVar var, Function function) { |
780 | // Set the current function being type checked. |
781 | this->current_func_ = var; |
782 | |
783 | // Step 1: Populate the constraints. |
784 | GetType(function); |
785 | |
786 | // Step 2: Solve the constraints. |
787 | Solve(); |
788 | |
789 | // Step 3: Attach resolved types to checked_type field. |
790 | auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(function); |
791 | |
792 | if (!WellFormed(resolved_expr, this->diag_ctx)) { |
793 | this->diag_ctx.Emit(Diagnostic::Bug(function->span) |
794 | << "the type checked function is malformed, please report this" ); |
795 | } |
796 | |
797 | return resolved_expr; |
798 | } |
799 | |
800 | struct AllCheckTypePopulated : MixedModeVisitor { |
801 | using MixedModeVisitor::VisitExpr_; |
802 | void DispatchExprVisit(const Expr& e) { |
803 | if (e.as<OpNode>()) { |
804 | return; |
805 | } |
806 | if (e.as<GlobalVarNode>()) { |
807 | return; |
808 | } |
809 | if (e.as<ConstructorNode>()) { |
810 | return; |
811 | } |
812 | ICHECK(e->checked_type_.defined()) << "Expression: " << e; |
813 | return ExprVisitor::VisitExpr(e); |
814 | } |
815 | void VisitExpr_(const LetNode* op) final { |
816 | auto pre_visit = [this](const LetNode* op) { |
817 | this->VisitExpr(op->var); |
818 | this->VisitExpr(op->value); |
819 | }; |
820 | auto post_visit = [this](const LetNode* op) { |
821 | this->VisitExpr(op->body); |
822 | this->visit_counter_[op] += 1; |
823 | }; |
824 | ExpandANormalForm(op, pre_visit, post_visit); |
825 | } |
826 | }; |
827 | |
828 | void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } |
829 | |
830 | // TODO(@jroesch): Can we optimize this? |
831 | void AddGlobalTypes(IRModule mod) { |
832 | std::vector<std::pair<GlobalVar, Function>> updates; |
833 | for (const auto& it : mod->functions) { |
834 | // Currently we don't type check TIR. |
835 | // The inferencer will only check Relay functions |
836 | // the future plan is to have a unified type checker |
837 | // that works on TIR and Relay at the same time. |
838 | if (auto* func_node = it.second.as<FunctionNode>()) { |
839 | Function func = Function(make_object<FunctionNode>(*func_node)); |
840 | func->checked_type_ = func->func_type_annotation(); |
841 | updates.push_back({it.first, Downcast<Function>(func)}); |
842 | } |
843 | } |
844 | |
845 | for (const auto& pair : updates) { |
846 | mod->Add(pair.first, pair.second, true); |
847 | } |
848 | } |
849 | |
850 | /*! |
851 | * \brief Returns a possibly much smaller subgraph whose inner nodes have the same type. |
852 | * |
853 | * Returns the largest sub-graph who's inner nodes need types and leaves are vars standing in |
854 | * for already typed sub-expressions. This creates a graph whose inner nodes have the same |
855 | * type as the original graph and when running type inference, we can avoid copying and |
856 | * recursing through most of the expression graph when running type inference. Note, this assumes |
857 | * that current populated type information is correct! |
858 | * |
859 | * ExprMutator is sufficient over MixedModemutator since we will not recurse much. |
860 | */ |
861 | class : public ExprMutator { |
862 | Expr (const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); } |
863 | Expr (const ConstantNode* op) { return Constant(op->data, op->span); } |
864 | Expr (const GlobalVarNode* op) { return GlobalVar(op->name_hint); } |
865 | Expr (const OpNode* op) { return Op(GetRef<Op>(op)); } |
866 | Expr (const TupleNode* op) { |
867 | return Tuple(GetAnalogousExpression(op->fields), op->span); |
868 | } |
869 | Expr (const FunctionNode* op) { |
870 | // Unfortunately our strategy of inserting variables as dummies would change the signature of |
871 | // existing function nodes so we have to copy all used functions always :/ |
872 | return Function(op->params, op->body, op->ret_type, op->type_params, op->attrs, op->span); |
873 | } |
874 | Expr (const CallNode* op) { |
875 | return Call(op->op, GetAnalogousExpression(op->args), op->attrs, op->type_args, op->span); |
876 | } |
877 | Expr (const LetNode* op) { |
878 | return Let(op->var, GetAnalogousExpression(op->value), GetAnalogousExpression(op->body), |
879 | op->span); |
880 | } |
881 | Expr (const IfNode* op) { |
882 | return If(GetAnalogousExpression(op->cond), GetAnalogousExpression(op->true_branch), |
883 | GetAnalogousExpression(op->false_branch), op->span); |
884 | } |
885 | Expr (const TupleGetItemNode* op) { |
886 | return TupleGetItem(GetAnalogousExpression(op->tuple), op->index, op->span); |
887 | } |
888 | Expr (const RefCreateNode* op) { |
889 | return RefCreate(GetAnalogousExpression(op->value), op->span); |
890 | } |
891 | Expr (const RefReadNode* op) { |
892 | return RefRead(GetAnalogousExpression(op->ref), op->span); |
893 | } |
894 | Expr (const RefWriteNode* op) { |
895 | return RefWrite(GetAnalogousExpression(op->ref), GetAnalogousExpression(op->value), op->span); |
896 | } |
897 | Expr (const ConstructorNode* op) { |
898 | return Constructor(op->name_hint, op->inputs, op->belong_to); |
899 | } |
900 | Expr (const MatchNode* op) { |
901 | return Match(GetAnalogousExpression(op->data), op->clauses, op->complete, op->span); |
902 | } |
903 | |
904 | private: |
905 | Expr (const Expr& expr) { |
906 | // Replace the expression with a potentially simpler expression of the same type |
907 | if (expr->checked_type_.defined()) { |
908 | // Since the expression already has a checked_type which we assume is correct we don't need |
909 | // full type inference to enter it. So stub it out with a dummy var of the same type. |
910 | return Var("dummy_var" , expr->checked_type(), expr->span); |
911 | } |
912 | |
913 | return VisitExpr(expr); |
914 | } |
915 | Array<Expr> (const Array<Expr>& fields) { |
916 | Array<Expr> new_fields; |
917 | for (Expr expr : fields) { |
918 | new_fields.push_back(GetAnalogousExpression(expr)); |
919 | } |
920 | return new_fields; |
921 | } |
922 | }; |
923 | |
924 | namespace transform { |
925 | |
926 | Type InferTypeLocal(const Expr& expr) { |
927 | /* |
928 | This type inference differs from InferType in that it uses existing type information |
929 | to avoid recursing over much of the graph, and it only examines the type of the input |
930 | node. This makes it faster if you need to run type inference iteratively throughout |
931 | a pass for example. |
932 | |
933 | However, it assumes any existing populated type inference is correct! If some populated |
934 | type inference is incorrect, an incorrect type may be returned or a type error will be |
935 | raised. If you know not all populated type fields are correct with the current graph, |
936 | you should use InferType() instead. |
937 | */ |
938 | SameTypedSubgraphExtractor ; |
939 | Expr sub_graph = subgraph_extractor(expr); |
940 | |
941 | Type result_type; |
942 | result_type = relay::InferType(sub_graph)->checked_type(); |
943 | |
944 | expr->checked_type_ = result_type; |
945 | return result_type; |
946 | } |
947 | |
948 | TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal" ).set_body_typed([](const Expr& expr) { |
949 | return InferTypeLocal(expr); |
950 | }); |
951 | |
952 | Pass InferType() { |
953 | auto pass_info = PassInfo(0, "InferType" , {}); |
954 | return tvm::transform::CreateModulePass( |
955 | [=](IRModule mod, const PassContext& pass_ctx) { |
956 | // Execute the pass function and return a new module. |
957 | IRModule updated_mod = mod->ShallowCopy(); |
958 | |
959 | pass_ctx->diag_ctx = DiagnosticContext::Default(updated_mod); |
960 | |
961 | // Add all the type annotations to the functions in the model. |
962 | AddGlobalTypes(mod); |
963 | |
964 | std::vector<std::pair<GlobalVar, Function>> updates; |
965 | for (const auto& it : updated_mod->functions) { |
966 | // Currently we don't type check TIR. |
967 | // |
968 | // The inferencer will only check Relay functions. |
969 | |
970 | // In the future we plan a unified type checker |
971 | // that works on TIR and Relay at the same time. |
972 | if (auto* func_node = it.second.as<FunctionNode>()) { |
973 | auto func = GetRef<Function>(func_node); |
974 | |
975 | // // If a function already has type information we can skip checking it. |
976 | // if (func->checked_type_.defined()) { |
977 | // continue; |
978 | // } |
979 | |
980 | // TODO(@jroesch): we should be able to move the type inferencer outside |
981 | // of this function but it seems to be more stateful then I expect. |
982 | auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value()); |
983 | auto updated_func = inferencer.Infer(it.first, func); |
984 | |
985 | pass_ctx->diag_ctx.value().Render(); |
986 | |
987 | // After we are done checking write the global type back |
988 | // into the global var. |
989 | it.first->checked_type_ = updated_func->checked_type(); |
990 | |
991 | if (!WellFormed(updated_func, pass_ctx->diag_ctx)) { |
992 | LOG(FATAL) << "The type checked intermediate representation is malformed" ; |
993 | } |
994 | |
995 | auto free_tvars = FreeTypeVars(updated_func, mod); |
996 | ICHECK(free_tvars.size() == 0) |
997 | << "Found unbound type variables in " << updated_func << ": " << free_tvars; |
998 | EnsureCheckedType(updated_func); |
999 | updates.push_back({it.first, Downcast<Function>(updated_func)}); |
1000 | } |
1001 | } |
1002 | |
1003 | for (const auto& pair : updates) { |
1004 | updated_mod->Add(pair.first, pair.second, true); |
1005 | } |
1006 | |
1007 | return updated_mod; |
1008 | }, |
1009 | 0, "InferType" , {}); |
1010 | } |
1011 | |
1012 | TVM_REGISTER_GLOBAL("relay._transform.InferType" ).set_body_typed([]() { return InferType(); }); |
1013 | |
1014 | } // namespace transform |
1015 | |
1016 | } // namespace relay |
1017 | } // namespace tvm |
1018 | |