/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file type_infer.cc
 * \brief Relay type inference and checking.
 *
 * This file implements one of the most important passes over the
 * Relay IR. Many transformations, and the generation of efficient
 * code, require type information for the IR.
 *
 * Similar to previous computation-graph-based IRs, the Relay IR leaves
 * type information implicit and computes types by performing program
 * analysis.
 *
 * Given an expression `e`, this pass infers a type `t` for
 * the expression while simultaneously checking the property `e : t`
 * (i.e., we can show that e has type t).
 *
 * If we cannot infer a type, or a constraint conflicts with another,
 * the pass emits errors.
 */
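/*
 * As a rough illustration (the names here are hypothetical), given
 *   fn (%x: Tensor[(3), float32]) { add(%x, %x) }
 * the pass infers the body type Tensor[(3), float32] and checks the whole
 * function against fn (Tensor[(3), float32]) -> Tensor[(3), float32].
 */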

#include <tvm/ir/transform.h>
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>

#include "../analysis/type_solver.h"
#include "pass_utils.h"

namespace tvm {
namespace relay {

// Necessary deferred relation for TupleGetItem
struct TupleGetItemAttrs : public tvm::AttrsNode<TupleGetItemAttrs> {
  int index;

  TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { TVM_ATTR_FIELD(index); }
};

bool TupleGetItemRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 2);
  if (types[0].as<IncompleteTypeNode>()) return false;
  const auto* data = types[0].as<TupleTypeNode>();
  ICHECK(data != nullptr) << "TupleGetItem expects the input type to be TupleType, got " << types[0]
                          << " instead";
  const auto* param = attrs.as<TupleGetItemAttrs>();
  ICHECK(param != nullptr);
  ICHECK_GE(param->index, 0);
  ICHECK_LT(param->index, data->fields.size());
  reporter->Assign(types[1], data->fields[param->index]);
  return true;
}

TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem").set_body_typed(TupleGetItemRel);
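
// As a rough illustration: for a value of type (Tensor[(3), float32], int32),
// TupleGetItemRel with index = 1 assigns int32 to the result type. The relation
// is deferred because the tuple's type may still be an IncompleteType when the
// TupleGetItem expression is first visited.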

struct ResolvedTypeInfo {
  explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
      : checked_type(checked_type), type_args(type_args) {}
  ResolvedTypeInfo() {}

  Type checked_type;
  // Only allocated when the expression is a call.
  Array<Type> type_args = Array<Type>(ObjectPtr<Object>(nullptr));
};

//
// The inference algorithm can roughly be divided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType)
//   - solver.AddConstraint and solver.Unify are called to populate the necessary constraints
// - Solve the constraints (solver_.Solve)
// - Recreate the expression with the resolved checked_type (Resolver.VisitExpr)
//
class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
                       private PatternFunctor<void(const Pattern&, const Type&)> {
 public:
  // constructors

  explicit TypeInferencer(IRModule mod, DiagnosticContext diag_ctx)
      : mod_(mod), diag_ctx(diag_ctx), solver_(GlobalVar(), diag_ctx) {
    ICHECK(mod.defined()) << "Module must not be null in the type inferencer.";
  }

  // Infer the types inside of a function.
  Expr Infer(GlobalVar var, Function function);

 private:
  // type resolver that maps back to type
  class Resolver;
  // internal environment
  IRModule mod_;

  // The current function being type checked.
  GlobalVar current_func_;

  /*! \brief The diagnostic context. */
  DiagnosticContext diag_ctx;

  // Map from expression to checked type.
  // The type inferencer populates it incrementally.
  std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual> type_map_;

  // The solver used by the inferencer.
  TypeSolver solver_;
  // relation functions
  TypeRelationFn tuple_getitem_rel_;
  TypeRelationFn make_tuple_rel_;

  /*! \brief Internal map used for memoization. */
  std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> memo_;

  void VisitLeaf(const Expr& expr) {
    if (!memo_.count(expr)) {
      Type ret = this->DispatchVisitExpr(expr);
      memo_[expr] = ret;
    }
  }

  bool CheckVisited(const Expr& expr) { return memo_.count(expr) != 0; }

  Type DispatchVisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); }

  Type VisitExpr(const Expr& expr) final {
    auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
    auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
    if (memo_.count(expr)) {
      return memo_[expr];
    } else {
      ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
      return memo_[expr];
    }
  }

  // Perform unification on two types and report any error at the span of the expression.
  Type Unify(const Type& t1, const Type& t2, const Span& span, bool assign_lhs = true,
             bool assign_rhs = true) {
    try {
      return solver_.Unify(t1, t2, span, assign_lhs, assign_rhs);
    } catch (const Error& e) {
      this->EmitFatal(Diagnostic::Error(span)
                      << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what());
      return Type();
    }
  }

  // Lazily get the type of an expression: if it has not been checked yet,
  // populate the type now and return the result.
  Type GetType(const Expr& expr) {
    auto it = type_map_.find(expr);
    if (it != type_map_.end() && it->second.checked_type.defined()) {
      return it->second.checked_type;
    }
    Type ret = this->VisitExpr(expr);
    ICHECK(ret.defined()) << "expression:" << std::endl << PrettyPrint(expr);
    KindCheck(ret, mod_, this->diag_ctx);
    ResolvedTypeInfo& rti = type_map_[expr];
    rti.checked_type = ret;
    return ret;
  }

  void EmitFatal(const Diagnostic& diag) { this->diag_ctx.EmitFatal(diag); }

  // Visitor Logic
  Type VisitExpr_(const VarNode* op) final {
    if (op->type_annotation.defined()) {
      return op->type_annotation;
    } else {
      return IncompleteType(Kind::kType);
    }
  }

  Type VisitExpr_(const GlobalVarNode* op) final {
    GlobalVar var = GetRef<GlobalVar>(op);
    if (!mod_.defined()) {
      this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables "
                                                  << "without a module");
    }
    if (mod_->ContainGlobalVar(var->name_hint)) {
      BaseFunc func = mod_->Lookup(var->name_hint);

      if (const auto* function_node = func.as<FunctionNode>()) {
        VLOG(1) << "global var '" << op->name_hint << "' bound to Function";
        return function_node->checked_type();
      } else {
        VLOG(1) << "global var '" << op->name_hint << "' bound to PrimFunc";
        return op->checked_type_;
      }
    } else {
      // TODO(mbs): extern function cleanup
      // Assume the function is extern and thus no longer in the IRModule.
      VLOG(1) << "global var '" << op->name_hint << "' not in module";
      return op->checked_type_;
    }
  }

  Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); }

  Type VisitExpr_(const TupleNode* op) final {
    Array<Type> types;
    for (Expr field : op->fields) {
      types.push_back(GetType(field));
    }
    return TupleType(types);
  }

  Type VisitExpr_(const TupleGetItemNode* op) final {
    if (!tuple_getitem_rel_.defined()) {
      tuple_getitem_rel_ =
          Downcast<TypeRelationFn>(EnvFunc::Get("tvm.relay.type_relation.TupleGetItem"));
    }
    Type tuple_type = GetType(op->tuple);
    Type rtype = IncompleteType(Kind::kType);
    auto attrs = make_object<TupleGetItemAttrs>();
    attrs->index = op->index;
    solver_.AddConstraint(TypeRelation(tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)),
                          op->span);
    return rtype;
  }

  void VisitPattern_(const PatternConstructorNode* con, const Type& t) {
    ICHECK(mod_.defined()) << "Cannot do type inference without an environment: "
                           << con->constructor->name_hint;
    TypeData td = mod_->type_definitions.at(con->constructor->belong_to);
    auto pc = GetRef<PatternConstructor>(con);

    // we can expect a certain number of arguments
    Array<Type> unknown_args;
    for (size_t i = 0; i < td->type_vars.size(); i++) {
      unknown_args.push_back(IncompleteType(Kind::kType));
    }

    Type expected = TypeCall(con->constructor->belong_to, unknown_args);
    Type unified = Unify(t, expected, pc->span);

    auto* tc = unified.as<TypeCallNode>();
    if (!tc) {
      this->EmitFatal(Diagnostic::Error(pc->span) << "Expected a type call, got " << unified);
    }

    if (td->header != tc->func) {
      this->EmitFatal(Diagnostic::Error(pc->span) << "ADT headers must match, but we have "
                                                  << td->header << " and " << tc->func);
    }

    if (td->type_vars.size() != tc->args.size()) {
      this->EmitFatal(Diagnostic::Error(pc->span)
                      << "The number of type args must match "
                      << "the number of type vars in the type data: " << td->type_vars.size()
                      << " != " << tc->args.size());
    }
    std::unordered_map<TypeVar, Type, ObjectPtrHash, ObjectPtrEqual> type_var_map_;
    for (size_t i = 0; i < td->type_vars.size(); ++i) {
      type_var_map_[td->type_vars[i]] = tc->args[i];
    }

    if (con->constructor->inputs.size() != con->patterns.size()) {
      this->EmitFatal(Diagnostic::Error(pc->span) << "Wrong number of inputs for the constructor; "
                                                  << "expected " << con->constructor->inputs.size()
                                                  << ", got " << con->patterns.size());
    }

    for (size_t i = 0; i < con->constructor->inputs.size(); ++i) {
      VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_));
    }
  }

  void VisitPattern_(const PatternTupleNode* tup, const Type& t) {
    auto pt = GetRef<PatternTuple>(tup);

    // we can expect a certain number of arguments
    Array<Type> unknown_args;
    for (size_t i = 0; i < tup->patterns.size(); i++) {
      unknown_args.push_back(IncompleteType(Kind::kType));
    }

    Type expected = TupleType(unknown_args);
    Type unified = Unify(t, expected, tup->span);

    auto* tt = unified.as<TupleTypeNode>();
    if (!tt) {
      this->EmitFatal(Diagnostic::Error(pt->span) << "Expected a tuple type, got " << unified);
    }
    ICHECK(tup->patterns.size() == tt->fields.size())
        << "tuple pattern arity does not match the tuple type";
    for (size_t i = 0; i < tup->patterns.size(); ++i) {
      VisitPattern(tup->patterns[i], tt->fields[i]);
    }
  }

  void VisitPattern_(const PatternVarNode* pv, const Type& t) {
    Type vt = GetType(pv->var);
    Unify(vt, t, pv->span);
  }

  void VisitPattern_(const PatternWildcardNode* wc, const Type& t) {}

  Type VisitExpr_(const MatchNode* op) final {
    Type dtype = GetType(op->data);
    for (const auto& c : op->clauses) {
      VisitPattern(c->lhs, dtype);
    }
    Type rtype = IncompleteType(Kind::kType);
    for (const auto& c : op->clauses) {
      rtype = this->Unify(rtype, GetType(c->rhs), op->span);
    }

    if (op->complete) {
      // check completeness
      Match match = GetRef<Match>(op);
      Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
      if (unmatched_cases.size() != 0) {
        auto err = Diagnostic::Error(match->span);
        err << "match expression does not handle the following cases: ";
        int i = 0;
        for (auto cs : unmatched_cases) {
          err << "case " << i++ << ": \n" << PrettyPrint(cs);
        }
        this->EmitFatal(err);
      }
    }

    return rtype;
  }

  Type VisitExpr_(const OpNode* op) final { return op->op_type; }

  Type VisitExpr_(const LetNode* let) final {
    auto pre_visit = [this](const LetNode* op) {
      // if the definition is a function literal, permit recursion
      bool is_functional_literal = op->value.as<FunctionNode>() != nullptr;
      Type let_type = IncompleteType(Kind::kType);

      if (is_functional_literal) {
        let_type = this->GetType(op->var);
        this->type_map_[op->var].checked_type = let_type;
      }

      if (op->var->type_annotation.defined()) {
        let_type = this->Unify(let_type, op->var->type_annotation, op->span);
      }

      Type vtype = this->GetType(op->value);
      let_type = this->Unify(let_type, vtype, op->span);

      ICHECK(is_functional_literal || !this->type_map_.count(op->var));
      // NOTE: no scoping is necessary because vars are unique in the program
      this->type_map_[op->var].checked_type = let_type;
    };
    auto post_visit = [this](const LetNode* op) {
      Expr expr = GetRef<Expr>(op);
      this->memo_[expr] = this->GetType(op->body);
      this->type_map_[expr].checked_type = this->memo_[expr];
    };
    ExpandANormalForm(let, pre_visit, post_visit);
    return memo_[GetRef<Expr>(let)];
  }
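
  // As a rough illustration (hypothetical names): in
  //   let %f = fn (%n) { ... %f(...) ... }; ...
  // the bound value is a function literal, so %f is assigned a type (initially
  // an IncompleteType) before its definition is checked, which lets the
  // recursive reference to %f inside the body unify against that type.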

  Type VisitExpr_(const IfNode* ite) final {
    // Ensure the type of the guard is Tensor[(), bool],
    // that is, a rank-0 boolean tensor.
    Type cond_type = this->GetType(ite->cond);
    this->Unify(cond_type, TensorType::Scalar(tvm::DataType::Bool()), ite->cond->span);
    Type checked_true = this->GetType(ite->true_branch);
    Type checked_false = this->GetType(ite->false_branch);
    return this->Unify(checked_true, checked_false, ite->span);
  }

  // This code is special-cased for primitive operators,
  // which are registered in the style defined in src/relay/op/*.
  //
  // The result will be the return type of the operator.
  Type PrimitiveCall(const FuncTypeNode* op, Array<Type> arg_types, const Attrs& attrs,
                     const Span& span) {
    if (op->type_params.size() != arg_types.size() + 1) return Type();
    if (op->type_constraints.size() != 1) return Type();
    const TypeRelationNode* rel = op->type_constraints[0].as<TypeRelationNode>();
    if (rel == nullptr) return Type();
    // validate that the type parameters match up with the relation arguments
    for (size_t i = 0; i < op->type_params.size(); ++i) {
      if (!op->type_params[i].same_as(rel->args[i])) return Type();
    }
    Type rtype = IncompleteType(Kind::kType);
    arg_types.push_back(rtype);
    // we can do simple replacement here
    solver_.AddConstraint(TypeRelation(rel->func, arg_types, arg_types.size() - 1, attrs), span);
    return rtype;
  }
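
  // As a rough illustration: a broadcasting operator such as add is registered
  // with a type of the form
  //   fn<t0, t1, t2>(t0, t1) -> t2 where Broadcast(t0, t1, t2)
  // so PrimitiveCall skips generic instantiation and directly emits the
  // relation Broadcast(arg0_type, arg1_type, rtype) with a fresh rtype.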

  // substitute the type args in the function type
  FuncType InstantiateFuncType(const FuncTypeNode* fn_ty, const Array<Type>& ty_args) {
    tvm::Map<TypeVar, Type> subst_map;

    // Build a substitution map up from the function type and type arguments.
    // Eventually allow the type vars to be passed in.
    ICHECK(fn_ty->type_params.size() == ty_args.size())
        << "number of type parameters does not match expected";
    for (size_t i = 0; i < ty_args.size(); ++i) {
      subst_map.Set(fn_ty->type_params[i], ty_args[i]);
    }

    Type ret_type = fn_ty->ret_type;

    // If the return type is not yet defined, place a new IncompleteType.
    // This relaxes fn_ty to inputs -> Any; type checking can still succeed
    // when there are additional constraints on the type.
    // This is a temporary workaround to check recursive functions whose
    // return type is not yet known.
    if (!ret_type.defined()) {
      ret_type = IncompleteType(Kind::kType);
    }

    Type inst_ty = FuncType(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints);
    inst_ty = Bind(inst_ty, subst_map);
    return Downcast<FuncType>(inst_ty);
  }
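
  // As a rough illustration: instantiating fn<t>(t) -> t with
  // ty_args = [Tensor[(3), float32]] yields
  // fn(Tensor[(3), float32]) -> Tensor[(3), float32].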

  // instantiate a polymorphic function type with fresh IncompleteTypes
  FuncType InstantiateFuncType(const FuncTypeNode* fn_ty) {
    if (fn_ty->type_params.size() == 0) {
      return GetRef<FuncType>(fn_ty);
    }

    Array<Type> type_args;
    for (size_t i = 0; i < fn_ty->type_params.size(); i++) {
      type_args.push_back(IncompleteType(Kind::kType));
    }
    return InstantiateFuncType(fn_ty, type_args);
  }

  void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
    auto type_info = type_map_.find(expr);
    if (type_info == type_map_.end()) {
      type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)});
    } else {
      ICHECK(!type_info->second.type_args.defined());
      type_info->second.type_args = type_args;
    }
  }

  // Handle a general call node.
  Type GeneralCall(const CallNode* call, Array<Type> arg_types) {
    Type ftype = GetType(call->op);
    auto* fn_ty_node = ftype.as<FuncTypeNode>();
    auto* inc_ty_node = ftype.as<IncompleteTypeNode>();

    if (fn_ty_node == nullptr && inc_ty_node == nullptr) {
      this->EmitFatal(Diagnostic::Error(call->span)
                      << "only expressions with function types can be called, found " << ftype);
    }

    // incomplete type => it must be a function taking the arg types
    // with an unknown return type
    if (inc_ty_node != nullptr) {
      Type ret_type = IncompleteType(Kind::kType);
      Type func_type = FuncType(arg_types, ret_type, {}, {});
      Type unified = this->Unify(ftype, func_type, call->op->span);
      fn_ty_node = unified.as<FuncTypeNode>();
    }

    Array<Type> type_args = call->type_args;
    if (type_args.size() > fn_ty_node->type_params.size()) {
      this->EmitFatal(Diagnostic::Error(call->span)
                      << "Incorrect number of type args in " << call->span << ": "
                      << "Expected " << fn_ty_node->type_params.size() << " but got "
                      << type_args.size() << " for call:\n"
                      << PrettyPrint(GetRef<Call>(call)));
    }
    for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) {
      type_args.push_back(IncompleteType(Kind::kType));
    }

    FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);

    AddTypeArgs(GetRef<Call>(call), type_args);

    size_t type_arity = fn_ty->arg_types.size();
    size_t number_of_args = arg_types.size();
    // Whether the op accepts a variable number of arguments (num_inputs == -1).
    bool is_variable = false;

    if (const OpNode* opnode = call->op.as<OpNode>()) {
      if (opnode->num_inputs == -1) {
        is_variable = true;
      }
    }

    if ((type_arity < number_of_args) && !is_variable) {
      this->EmitFatal(Diagnostic::Error(call->span)
                      << "the function is provided too many arguments; "
                      << "expected " << type_arity << ", found " << number_of_args);
    } else if (type_arity > number_of_args) {
      this->EmitFatal(Diagnostic::Error(call->span)
                      << "the function is provided too few arguments; "
                      << "expected " << type_arity << ", found " << number_of_args);
    }

    Array<Type> unified_arg_types;
    if (!is_variable) {
      for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
        this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, true, false);
      }
    } else {
      for (size_t i = 0; i < number_of_args; i++) {
        if (i < fn_ty->arg_types.size()) {
          unified_arg_types.push_back(
              this->Unify(fn_ty->arg_types[i], arg_types[i], call->span, false, false));
        } else {
          unified_arg_types.push_back(arg_types[i]);
        }
      }
      unified_arg_types.push_back(fn_ty->ret_type);
    }
    for (auto cs : fn_ty->type_constraints) {
      if (const auto* tr = cs.as<TypeRelationNode>()) {
        if (!is_variable) {
          solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs),
                                call->span);
        } else {
          solver_.AddConstraint(
              TypeRelation(tr->func, unified_arg_types, number_of_args, call->attrs), call->span);
        }
      } else {
        solver_.AddConstraint(cs, call->span);
      }
    }

    return fn_ty->ret_type;
  }
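
  // As a rough illustration (hypothetical names): for a call %f(%x) where
  // %f : fn<t>(t) -> t and %x : Tensor[(3), float32], a fresh IncompleteType
  // I0 is pushed for t, the instantiated type fn(I0) -> I0 is unified
  // argument-by-argument against [Tensor[(3), float32]], and I0 (and hence the
  // returned type) resolves to Tensor[(3), float32] once the solver runs.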

  Type VisitExpr_(const CallNode* call) final {
    Array<Type> arg_types;
    for (Expr arg : call->args) {
      arg_types.push_back(GetType(arg));
    }

    if (const OpNode* opnode = call->op.as<OpNode>()) {
      Type rtype =
          PrimitiveCall(opnode->op_type.as<FuncTypeNode>(), arg_types, call->attrs, call->span);

      if (rtype.defined()) {
        AddTypeArgs(GetRef<Call>(call), arg_types);
        return rtype;
      }
    }

    solver_.Solve();
    return GeneralCall(call, arg_types);
  }

  Type VisitExpr_(const FunctionNode* f) final {
    solver_.Solve();
    Array<Type> arg_types;
    for (auto param : f->params) {
      arg_types.push_back(GetType(param));
    }
    Type rtype = GetType(f->body);
    if (auto* ft = rtype.as<FuncTypeNode>()) {
      rtype = InstantiateFuncType(ft);
    }
    if (f->ret_type.defined()) {
      rtype = this->Unify(f->ret_type, rtype, GetRef<Function>(f)->span);
    }
    ICHECK(rtype.defined());
    auto ret = FuncType(arg_types, rtype, f->type_params, {});
    return solver_.Resolve(ret);
  }

  Type VisitExpr_(const RefCreateNode* op) final { return RelayRefType(GetType(op->value)); }

  Type VisitExpr_(const RefReadNode* op) final {
    Type it = IncompleteType(Kind::kType);
    this->Unify(GetType(op->ref), RelayRefType(it), op->span);
    return it;
  }

  Type VisitExpr_(const RefWriteNode* op) final {
    Type it = IncompleteType(Kind::kType);
    this->Unify(GetType(op->ref), RelayRefType(it), op->span);
    this->Unify(GetType(op->value), it, op->span);
    return TupleType::Empty();
  }

  Type VisitExpr_(const ConstructorNode* c) final {
    ICHECK(mod_.defined()) << "Cannot do type inference without an environment: " << c->name_hint;
    TypeData td = mod_->LookupTypeDef(c->belong_to);
    std::vector<Type> types;
    for (const auto& t : td->type_vars) {
      types.push_back(t);
    }
    return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {});
  }

  void Solve() { solver_.Solve(); }
};

class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
 public:
  Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap,
           TypeSolver* solver)
      : tmap_(tmap), solver_(solver) {}

  using MixedModeMutator::VisitExpr_;

  Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }

  Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const GlobalVarNode* op) final { return GetRef<GlobalVar>(op); }

  Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); }

  Expr Rewrite_(const TupleNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

  Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
    return AttachCheckedType(op, post);
  }

  Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); }

  Expr Rewrite_(const CallNode* op, const Expr& post) final { return AttachCheckedType(op, post); }

  Expr VisitExpr_(const LetNode* op) final {
    auto pre_visit = [this](const LetNode* op) {
      this->VisitExpr(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      Expr expr = GetRef<Expr>(op);
      Var var = Downcast<Var>(this->VisitExpr(op->var));
      Expr value = this->VisitExpr(op->value);
      Expr body = this->VisitExpr(op->body);
      this->memo_[expr] = this->AttachCheckedType(op, Let(var, value, body));
    };
    ExpandANormalForm(op, pre_visit, post_visit);
    return memo_[GetRef<Expr>(op)];
  }

  Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const RefCreateNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const RefReadNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const RefWriteNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const ConstructorNode* op) final { return AttachCheckedType(op); }

  Expr VisitExpr_(const MatchNode* op) final { return AttachCheckedType(op); }

  Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }

  Var VisitVar(const Var& v) final {
    if (vmap_.count(v) == 0) {
      vmap_[v] = GetRef<Var>(AttachCheckedType(v.as<VarNode>()).as<VarNode>());
    }
    return vmap_.at(v);
  }

  // attach checked type to the mutated node.
  template <typename T>
  Expr AttachCheckedType(const T* op, const Expr& post = Expr()) {
    auto it = tmap_.find(GetRef<Expr>(op));
    ICHECK(it != tmap_.end());
    Type checked_type = solver_->Resolve(it->second.checked_type);

    if (checked_type.as<IncompleteTypeNode>() != nullptr) {
      this->solver_->Emit(
          Diagnostic::Error(op->span)
          << "The type inference pass was unable to infer a type for this expression.\n"
          << "This usually occurs when an operator call is under-constrained in some way;"
          << " check other reported errors for hints of what may have happened.");
    }

    Expr new_e = post.defined() ? post : ExprMutator::VisitExpr_(op);
    // new_call and new_var are only valid pointers for CallNode/VarNode respectively.
    // Compiler optimization will likely fold these away for other nodes.
    CallNode* new_call = (std::is_base_of<CallNode, T>::value
                              ? const_cast<CallNode*>(static_cast<const CallNode*>(new_e.get()))
                              : nullptr);
    VarNode* new_var = (std::is_base_of<VarNode, T>::value
                            ? const_cast<VarNode*>(static_cast<const VarNode*>(new_e.get()))
                            : nullptr);
    FunctionNode* new_fn =
        (std::is_base_of<FunctionNode, T>::value
             ? const_cast<FunctionNode*>(static_cast<const FunctionNode*>(new_e.get()))
             : nullptr);

    // check if we need to update new_e
    bool need_update_type = !checked_type.same_as(new_e->checked_type_);
    bool need_update_call =
        (std::is_base_of<CallNode, T>::value && it->second.type_args.defined() &&
         !it->second.type_args.same_as(new_call->type_args));
    bool need_update_var = (std::is_base_of<VarNode, T>::value && update_missing_type_annotation_ &&
                            !new_var->type_annotation.defined());

    bool need_update_fn = (std::is_base_of<FunctionNode, T>::value &&
                           update_missing_type_annotation_ && !new_fn->ret_type.defined());

    if (!need_update_type && !need_update_var && !need_update_call && !need_update_fn) {
      return new_e;
    }

    if (!new_e.unique()) {
      // Copy-on-write optimization:
      // if new_e is shared with an old expression,
      // we make a copy instead of mutating an existing reference.
      ObjectPtr<ExprNode> ptr = make_object<T>(*new_e.as<T>());
      new_e = Expr(ptr);
      new_call =
          (std::is_base_of<CallNode, T>::value ? static_cast<CallNode*>(ptr.get()) : nullptr);
      new_var = (std::is_base_of<VarNode, T>::value ? static_cast<VarNode*>(ptr.get()) : nullptr);
      new_fn = (std::is_base_of<FunctionNode, T>::value ? static_cast<FunctionNode*>(ptr.get())
                                                        : nullptr);
    }

    // attach the information.
    if (need_update_type) {
      new_e->checked_type_ = checked_type;
    }

    if (need_update_call) {
      new_call->type_args = it->second.type_args;
      for (size_t i = 0; i < new_call->type_args.size(); i++) {
        new_call->type_args.Set(i, solver_->Resolve(new_call->type_args[i]));
      }
    }
    if (need_update_var) {
      new_var->type_annotation = checked_type;
    }
    if (need_update_fn) {
      auto* fn_type = checked_type.as<FuncTypeNode>();
      ICHECK(fn_type != nullptr);
      new_fn->ret_type = fn_type->ret_type;
    }
    return new_e;
  }

  Type VisitType(const Type& t) final { return solver_->Resolve(t); }

 private:
  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> vmap_;
  const std::unordered_map<Expr, ResolvedTypeInfo, ObjectPtrHash, ObjectPtrEqual>& tmap_;
  TypeSolver* solver_;
  // Whether to attach the checked type as type_annotation
  // if the original type annotation is missing.
  bool update_missing_type_annotation_{true};
};

Expr TypeInferencer::Infer(GlobalVar var, Function function) {
  // Set the current function being type checked.
  this->current_func_ = var;

  // Step 1: Populate the constraints.
  GetType(function);

  // Step 2: Solve the constraints.
  Solve();

  // Step 3: Attach the resolved types to the checked_type field.
  auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(function);

  if (!WellFormed(resolved_expr, this->diag_ctx)) {
    this->diag_ctx.Emit(Diagnostic::Bug(function->span)
                        << "the type checked function is malformed, please report this");
  }

  return resolved_expr;
}

struct AllCheckTypePopulated : MixedModeVisitor {
  using MixedModeVisitor::VisitExpr_;
  void DispatchExprVisit(const Expr& e) {
    if (e.as<OpNode>()) {
      return;
    }
    if (e.as<GlobalVarNode>()) {
      return;
    }
    if (e.as<ConstructorNode>()) {
      return;
    }
    ICHECK(e->checked_type_.defined()) << "Expression: " << e;
    return ExprVisitor::VisitExpr(e);
  }
  void VisitExpr_(const LetNode* op) final {
    auto pre_visit = [this](const LetNode* op) {
      this->VisitExpr(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      this->VisitExpr(op->body);
      this->visit_counter_[op] += 1;
    };
    ExpandANormalForm(op, pre_visit, post_visit);
  }
};

void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); }

// TODO(@jroesch): Can we optimize this?
void AddGlobalTypes(IRModule mod) {
  std::vector<std::pair<GlobalVar, Function>> updates;
  for (const auto& it : mod->functions) {
    // Currently we don't type check TIR;
    // the inferencer will only check Relay functions.
    // The future plan is to have a unified type checker
    // that works on TIR and Relay at the same time.
    if (auto* func_node = it.second.as<FunctionNode>()) {
      Function func = Function(make_object<FunctionNode>(*func_node));
      func->checked_type_ = func->func_type_annotation();
      updates.push_back({it.first, func});
    }
  }

  for (const auto& pair : updates) {
    mod->Add(pair.first, pair.second, true);
  }
}

/*!
 * \brief Returns a possibly much smaller subgraph whose inner nodes have the same types.
 *
 * Returns the largest sub-graph whose inner nodes need types and whose leaves are vars standing in
 * for already-typed sub-expressions. This creates a graph whose inner nodes have the same
 * types as the original graph, so when running type inference we can avoid copying and
 * recursing through most of the expression graph. Note, this assumes that the currently
 * populated type information is correct!
 *
 * ExprMutator is sufficient over MixedModeMutator since we will not recurse much.
 */
class SameTypedSubgraphExtractor : public ExprMutator {
  Expr VisitExpr_(const VarNode* op) { return Var(op->vid, op->type_annotation, op->span); }
  Expr VisitExpr_(const ConstantNode* op) { return Constant(op->data, op->span); }
  Expr VisitExpr_(const GlobalVarNode* op) { return GlobalVar(op->name_hint); }
  Expr VisitExpr_(const OpNode* op) { return Op(GetRef<Op>(op)); }
  Expr VisitExpr_(const TupleNode* op) {
    return Tuple(GetAnalogousExpression(op->fields), op->span);
  }
  Expr VisitExpr_(const FunctionNode* op) {
    // Unfortunately our strategy of inserting variables as dummies would change the signature of
    // existing function nodes, so we always have to copy all used functions.
    return Function(op->params, op->body, op->ret_type, op->type_params, op->attrs, op->span);
  }
  Expr VisitExpr_(const CallNode* op) {
    return Call(op->op, GetAnalogousExpression(op->args), op->attrs, op->type_args, op->span);
  }
  Expr VisitExpr_(const LetNode* op) {
    return Let(op->var, GetAnalogousExpression(op->value), GetAnalogousExpression(op->body),
               op->span);
  }
  Expr VisitExpr_(const IfNode* op) {
    return If(GetAnalogousExpression(op->cond), GetAnalogousExpression(op->true_branch),
              GetAnalogousExpression(op->false_branch), op->span);
  }
  Expr VisitExpr_(const TupleGetItemNode* op) {
    return TupleGetItem(GetAnalogousExpression(op->tuple), op->index, op->span);
  }
  Expr VisitExpr_(const RefCreateNode* op) {
    return RefCreate(GetAnalogousExpression(op->value), op->span);
  }
  Expr VisitExpr_(const RefReadNode* op) {
    return RefRead(GetAnalogousExpression(op->ref), op->span);
  }
  Expr VisitExpr_(const RefWriteNode* op) {
    return RefWrite(GetAnalogousExpression(op->ref), GetAnalogousExpression(op->value), op->span);
  }
  Expr VisitExpr_(const ConstructorNode* op) {
    return Constructor(op->name_hint, op->inputs, op->belong_to);
  }
  Expr VisitExpr_(const MatchNode* op) {
    return Match(GetAnalogousExpression(op->data), op->clauses, op->complete, op->span);
  }

 private:
  Expr GetAnalogousExpression(const Expr& expr) {
    // Replace the expression with a potentially simpler expression of the same type.
    if (expr->checked_type_.defined()) {
      // Since the expression already has a checked_type which we assume is correct we don't need
      // full type inference to enter it. So stub it out with a dummy var of the same type.
      return Var("dummy_var", expr->checked_type(), expr->span);
    }

    return VisitExpr(expr);
  }
  Array<Expr> GetAnalogousExpression(const Array<Expr>& fields) {
    Array<Expr> new_fields;
    for (Expr expr : fields) {
      new_fields.push_back(GetAnalogousExpression(expr));
    }
    return new_fields;
  }
};
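
// As a rough illustration (hypothetical names): extracting from
// add(%already_typed, %untyped), where %already_typed has a populated
// checked_type_, yields add(%dummy_var, <extracted %untyped>), so type
// inference only needs to revisit the untyped portion of the graph.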

namespace transform {

Type InferTypeLocal(const Expr& expr) {
  /*
  This type inference differs from InferType in that it uses existing type information
  to avoid recursing over much of the graph, and it only examines the type of the input
  node. This makes it faster if you need to run type inference iteratively throughout
  a pass, for example.

  However, it assumes any existing populated type information is correct! If some populated
  type information is incorrect, an incorrect type may be returned or a type error will be
  raised. If you know that not all populated type fields are correct with the current graph,
  you should use InferType() instead.
  */
  SameTypedSubgraphExtractor subgraph_extractor;
  Expr sub_graph = subgraph_extractor(expr);

  Type result_type = relay::InferType(sub_graph)->checked_type();

  expr->checked_type_ = result_type;
  return result_type;
}

TVM_REGISTER_GLOBAL("relay._transform.InferTypeLocal").set_body_typed([](const Expr& expr) {
  return InferTypeLocal(expr);
});
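
// A minimal usage sketch (the expression here is hypothetical): after building
// a new node inside a pass, one can type it in place without re-checking the
// whole module, e.g.
//   Expr new_call = Call(op, args, attrs, type_args, span);
//   InferTypeLocal(new_call);  // populates new_call->checked_type_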

Pass InferType() {
  return tvm::transform::CreateModulePass(
      [=](IRModule mod, const PassContext& pass_ctx) {
        // Execute the pass function and return a new module.
        IRModule updated_mod = mod->ShallowCopy();

        pass_ctx->diag_ctx = DiagnosticContext::Default(updated_mod);

        // Add all the type annotations to the functions in the module.
        AddGlobalTypes(mod);

        std::vector<std::pair<GlobalVar, Function>> updates;
        for (const auto& it : updated_mod->functions) {
          // Currently we don't type check TIR;
          // the inferencer will only check Relay functions.
          //
          // In the future we plan a unified type checker
          // that works on TIR and Relay at the same time.
          if (auto* func_node = it.second.as<FunctionNode>()) {
            auto func = GetRef<Function>(func_node);

            // TODO(@jroesch): we should be able to move the type inferencer outside
            // of this function, but it seems to be more stateful than I expect.
            auto inferencer = TypeInferencer(mod, pass_ctx->diag_ctx.value());
            auto updated_func = inferencer.Infer(it.first, func);

            pass_ctx->diag_ctx.value().Render();

            // After we are done checking, write the global type back
            // into the global var.
            it.first->checked_type_ = updated_func->checked_type();

            if (!WellFormed(updated_func, pass_ctx->diag_ctx)) {
              LOG(FATAL) << "The type checked intermediate representation is malformed";
            }

            auto free_tvars = FreeTypeVars(updated_func, mod);
            ICHECK(free_tvars.size() == 0)
                << "Found unbound type variables in " << updated_func << ": " << free_tvars;
            EnsureCheckedType(updated_func);
            updates.push_back({it.first, Downcast<Function>(updated_func)});
          }
        }

        for (const auto& pair : updates) {
          updated_mod->Add(pair.first, pair.second, true);
        }

        return updated_mod;
      },
      0, "InferType", {});
}

TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); });

}  // namespace transform

}  // namespace relay
}  // namespace tvm