1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file type_solver.cc |
22 | * \brief Type solver implementations. |
23 | */ |
24 | #include "type_solver.h" |
25 | |
26 | #include <tvm/ir/type_functor.h> |
27 | #include <tvm/node/structural_equal.h> |
28 | #include <tvm/tir/op.h> |
29 | |
30 | #include <memory> |
31 | #include <string> |
32 | #include <tuple> |
33 | #include <utility> |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | |
38 | class TypeSolver::Reporter : public TypeReporterNode { |
39 | public: |
40 | explicit Reporter(TypeSolver* solver) : solver_(solver) {} |
41 | |
42 | void Assign(const Type& dst, const Type& src) final { solver_->Unify(dst, src, span); } |
43 | |
44 | bool Assert(const IndexExpr& cond) final { |
45 | if (const int64_t* pdiff = tir::as_const_int(cond)) { |
46 | return pdiff[0]; |
47 | } |
48 | return true; |
49 | } |
50 | |
51 | bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) final { |
52 | // early warning constant case. |
53 | IndexExpr diff = lhs - rhs; |
54 | if (const int64_t* pdiff = tir::as_const_int(diff)) { |
55 | return pdiff[0] == 0; |
56 | } |
57 | return true; |
58 | } |
59 | |
60 | TVM_DLL void SetSpan(const Span& span) final { this->span = span; } |
61 | |
62 | TVM_DLL Span GetSpan() final { return this->span; } |
63 | |
64 | TVM_DLL DiagnosticContext GetDiagCtx() final { return this->solver_->diag_ctx_; } |
65 | |
66 | // TVM_DLL void Emit(Diagnostic diagnostic) final { |
67 | // return this->solver_-> |
68 | // } |
69 | |
70 | TVM_DLL IRModule GetModule() final { return this->solver_->module_; } |
71 | |
72 | private: |
73 | /*! \brief The span to report unification errors at. */ |
74 | mutable Span span; |
75 | |
76 | TypeSolver* solver_; |
77 | }; |
78 | |
79 | class TypeSolver::OccursChecker : public TypeVisitor { |
80 | public: |
81 | explicit OccursChecker(TypeSolver* solver, TypeNode* var) |
82 | : solver_(solver), var_(var), found_(false) {} |
83 | |
84 | bool Check(const Type& t) { |
85 | VisitType(t); |
86 | return found_; |
87 | } |
88 | |
89 | void VisitType_(const IncompleteTypeNode* op) override { |
90 | IncompleteType t = GetRef<IncompleteType>(op); |
91 | TypeNode* node = solver_->GetTypeNode(t); |
92 | found_ = found_ || (var_->FindRoot() == node->FindRoot()); |
93 | } |
94 | |
95 | private: |
96 | TypeSolver* solver_; |
97 | TypeNode* var_; |
98 | bool found_; |
99 | }; |
100 | |
101 | class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { |
102 | public: |
103 | explicit Unifier(TypeSolver* solver, const Span& span) : solver_(solver), span(span) {} |
104 | |
105 | Type Unify(const Type& lhs_type, const Type& rhs_type, bool assign_lhs = true, |
106 | bool assign_rhs = true) { |
107 | // Known limitation |
108 | // - handle shape pattern matching |
109 | TypeNode* lhs = solver_->GetTypeNode(lhs_type); |
110 | TypeNode* rhs = solver_->GetTypeNode(rhs_type); |
111 | |
112 | // do occur check so we don't create self-referencing structure |
113 | if (lhs->FindRoot() == rhs->FindRoot()) { |
114 | return lhs->resolved_type; |
115 | } |
116 | |
117 | if (lhs->resolved_type.as<IncompleteTypeNode>()) { |
118 | ICHECK(!OccursCheck(lhs, rhs->resolved_type)) |
119 | << "Incomplete type " << lhs->resolved_type << " occurs in " << rhs->resolved_type |
120 | << ", cannot unify" ; |
121 | |
122 | solver_->MergeFromTo(lhs, rhs); |
123 | return rhs->resolved_type; |
124 | } else if (rhs->resolved_type.as<IncompleteTypeNode>()) { |
125 | ICHECK(!OccursCheck(rhs, lhs->resolved_type)) |
126 | << "Incomplete type " << rhs->resolved_type << " occurs in " << lhs->resolved_type |
127 | << ", cannot unify" ; |
128 | solver_->MergeFromTo(rhs, lhs); |
129 | return lhs->resolved_type; |
130 | } else { |
131 | Type resolved = this->VisitType(rhs->resolved_type, lhs->resolved_type); |
132 | |
133 | if (!resolved.defined()) { |
134 | solver_->Emit(Diagnostic::Error(this->span) |
135 | << "The Relay type checker is unable to show the following types match.\n" |
136 | << "In particular " |
137 | << "`" << PrettyPrint(lhs->resolved_type) << "` does not match `" |
138 | << PrettyPrint(rhs->resolved_type) << "`" ); |
139 | return lhs->resolved_type; |
140 | } else { |
141 | TypeNode* top = solver_->GetTypeNode(resolved); |
142 | if (assign_lhs) solver_->MergeFromTo(lhs, top); |
143 | if (assign_rhs) solver_->MergeFromTo(rhs, top); |
144 | return resolved; |
145 | } |
146 | } |
147 | } |
148 | |
149 | // Checks whether lhs (taken to be a type var) occurs in t, meaning |
150 | // there is a recursive equality constraint, which should be rejected. |
151 | // N.b.: A tautology like ?a = ?a is okay and should be checked for |
152 | // *before* calling this method |
153 | // |
154 | // See: https://en.wikipedia.org/wiki/Occurs_check |
155 | bool OccursCheck(TypeNode* lhs, const Type& t) { |
156 | OccursChecker rc(solver_, lhs); |
157 | return rc.Check(t); |
158 | } |
159 | |
160 | // default: unify only if structural-equal |
161 | Type VisitTypeDefault_(const Object* op, const Type& tn) final { |
162 | ObjectRef nr = GetRef<ObjectRef>(op); |
163 | Type t1 = GetRef<Type>(nr.as<tvm::relay::TypeNode>()); |
164 | if (!tvm::StructuralEqual()(t1, tn)) { |
165 | return Type(nullptr); |
166 | } |
167 | return t1; |
168 | } |
169 | |
170 | IndexExpr GetShape(const IndexExpr& e) { |
171 | IndexExpr ex = e; |
172 | while (true) { |
173 | auto it = solver_->shape_uf_.find(ex); |
174 | if (it == solver_->shape_uf_.end()) { |
175 | return ex; |
176 | } else { |
177 | ex = (*it).second; |
178 | } |
179 | } |
180 | } |
181 | |
182 | IndexExpr UnifyDim(const IndexExpr& lhs, const IndexExpr& rhs) { |
183 | auto ulhs = GetShape(lhs); |
184 | auto urhs = GetShape(rhs); |
185 | |
186 | if (ulhs.same_as(urhs)) { |
187 | return ulhs; |
188 | } |
189 | if (ulhs.as<AnyNode>() || urhs.as<AnyNode>()) { |
190 | return Any(); |
191 | } |
192 | |
193 | auto left_index0 = ulhs.as<tvm::tir::VarNode>(); |
194 | auto right_index0 = urhs.as<tvm::IntImmNode>(); |
195 | if (left_index0 && right_index0) { |
196 | solver_->shape_uf_.Set(ulhs, urhs); |
197 | return urhs; |
198 | } |
199 | |
200 | auto left_index1 = ulhs.as<tvm::IntImmNode>(); |
201 | auto right_index1 = urhs.as<tvm::tir::VarNode>(); |
202 | if (left_index1 && right_index1) { |
203 | solver_->shape_uf_.Set(urhs, ulhs); |
204 | return ulhs; |
205 | } |
206 | |
207 | auto left_index2 = ulhs.as<tvm::IntImmNode>(); |
208 | auto right_index2 = urhs.as<tvm::IntImmNode>(); |
209 | if (left_index2 && right_index2 && left_index2->value == right_index2->value) { |
210 | return ulhs; |
211 | } |
212 | |
213 | return tvm::PrimExpr(); |
214 | } |
215 | |
216 | Type VisitType_(const TensorTypeNode* op, const Type& tn) final { |
217 | const auto* tt_node = tn.as<TensorTypeNode>(); |
218 | if (!tt_node) { |
219 | return Type(nullptr); |
220 | } |
221 | |
222 | auto tt1 = GetRef<TensorType>(op); |
223 | auto tt2 = GetRef<TensorType>(tt_node); |
224 | |
225 | if (tvm::StructuralEqual()(tt1, tt2)) { |
226 | return std::move(tt1); |
227 | } |
228 | |
229 | if (tt1->dtype != tt2->dtype) { |
230 | return Type(nullptr); |
231 | } |
232 | |
233 | tvm::Array<IndexExpr> shape; |
234 | if (tt1->shape.size() != tt2->shape.size()) { |
235 | this->solver_->Emit(Diagnostic::Error(this->span) |
236 | << "tensor type `" << PrettyPrint(tt1) << "` has " << tt1->shape.size() |
237 | << " dimensions, while `" << PrettyPrint(tt2) << "` has " |
238 | << tt2->shape.size() << " dimensions" ); |
239 | return Type(nullptr); |
240 | } |
241 | |
242 | std::vector<std::tuple<size_t, IndexExpr, IndexExpr>> mismatches; |
243 | |
244 | ICHECK_EQ(tt1->shape.size(), tt2->shape.size()); |
245 | for (size_t i = 0; i < tt1->shape.size(); i++) { |
246 | auto dim = UnifyDim(tt1->shape[i], tt2->shape[i]); |
247 | if (!dim.defined()) { |
248 | // NB: We push an arbitrary dimension here so we can continue error propagation. |
249 | shape.push_back(tt1->shape[i]); |
250 | tvm::PrimExpr shape1 = tt1->shape[i]; |
251 | tvm::PrimExpr shape2 = tt2->shape[i]; |
252 | std::tuple<int, IndexExpr, IndexExpr> tuple = std::make_tuple(i, shape1, shape2); |
253 | mismatches.push_back(tuple); |
254 | } else { |
255 | shape.push_back(dim); |
256 | } |
257 | } |
258 | |
259 | if (mismatches.size() != 0) { |
260 | auto err = Diagnostic::Error(this->span); |
261 | err << "The Relay type checker is unable to show the following types match:\n" |
262 | << " " << PrettyPrint(tt1) << "\n" |
263 | << " " << PrettyPrint(tt2) << "\n" ; |
264 | err << "In particular:\n" ; |
265 | for (auto mismatch : mismatches) { |
266 | err << " dimension " << std::get<0>(mismatch) << " conflicts: " << std::get<1>(mismatch) |
267 | << " does not match " << std::get<2>(mismatch) << "." ; |
268 | } |
269 | this->solver_->Emit(err); |
270 | return Type(nullptr); |
271 | } |
272 | |
273 | return TensorType(shape, tt1->dtype); |
274 | } |
275 | |
276 | Type VisitType_(const TupleTypeNode* op, const Type& tn) final { |
277 | const auto* ttn = tn.as<TupleTypeNode>(); |
278 | if (!ttn || op->fields.size() != ttn->fields.size()) { |
279 | return Type(nullptr); |
280 | } |
281 | |
282 | TupleType tt1 = GetRef<TupleType>(op); |
283 | TupleType tt2 = GetRef<TupleType>(ttn); |
284 | |
285 | std::vector<Type> new_fields; |
286 | for (size_t i = 0; i < tt1->fields.size(); i++) { |
287 | Type field = Unify(tt1->fields[i], tt2->fields[i]); |
288 | new_fields.push_back(field); |
289 | } |
290 | return TupleType(new_fields); |
291 | } |
292 | |
293 | Type VisitType_(const FuncTypeNode* op, const Type& tn) final { |
294 | const auto* ftn = tn.as<FuncTypeNode>(); |
295 | if (!ftn || op->arg_types.size() != ftn->arg_types.size() || |
296 | op->type_constraints.size() != ftn->type_constraints.size()) { |
297 | return Type(nullptr); |
298 | } |
299 | |
300 | // without loss of generality, suppose op->type_params.size() >= ftn->type_params.size(). |
301 | if (op->type_params.size() < ftn->type_params.size()) { |
302 | return VisitType_(ftn, GetRef<FuncType>(op)); |
303 | } |
304 | |
305 | // remap type vars so they match |
306 | Map<TypeVar, Type> subst_map; |
307 | tvm::Array<TypeVar> ft_type_params; |
308 | for (size_t i = 0; i < ftn->type_params.size(); ++i) { |
309 | subst_map.Set(op->type_params[i], ftn->type_params[i]); |
310 | ft_type_params.push_back(op->type_params[i]); |
311 | } |
312 | |
313 | for (size_t i = ftn->type_params.size(); i < op->type_params.size(); ++i) { |
314 | subst_map.Set(op->type_params[i], IncompleteType(kType)); |
315 | } |
316 | |
317 | FuncType ft = FuncType(op->arg_types, op->ret_type, ft_type_params, op->type_constraints); |
318 | auto ft1 = Downcast<FuncType>(Bind(ft, subst_map)); |
319 | auto ft2 = GetRef<FuncType>(ftn); |
320 | |
321 | Type ret_type = Unify(ft1->ret_type, ft2->ret_type); |
322 | |
323 | std::vector<Type> arg_types; |
324 | for (size_t i = 0; i < ft2->arg_types.size(); ++i) { |
325 | Type arg_type = Unify(ft1->arg_types[i], ft2->arg_types[i]); |
326 | arg_types.push_back(arg_type); |
327 | } |
328 | |
329 | std::vector<TypeConstraint> type_constraints; |
330 | for (size_t i = 0; i < ft1->type_constraints.size(); ++i) { |
331 | Type unified_constraint = Unify(ft1->type_constraints[i], ft2->type_constraints[i]); |
332 | const auto* tcn = unified_constraint.as<TypeConstraintNode>(); |
333 | ICHECK(tcn) << "Two type constraints unified into a non-constraint?" |
334 | << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; |
335 | type_constraints.push_back(GetRef<TypeConstraint>(tcn)); |
336 | } |
337 | |
338 | return FuncType(arg_types, ret_type, ft2->type_params, type_constraints); |
339 | } |
340 | |
341 | Type VisitType_(const RelayRefTypeNode* op, const Type& tn) final { |
342 | const auto* rtn = tn.as<RelayRefTypeNode>(); |
343 | if (!rtn) { |
344 | return Type(nullptr); |
345 | } |
346 | return RelayRefType(Unify(op->value, rtn->value)); |
347 | } |
348 | |
349 | Type VisitType_(const TypeCallNode* op, const Type& tn) override { |
350 | const auto* tcn = tn.as<TypeCallNode>(); |
351 | if (!tcn || tcn->args.size() != op->args.size()) { |
352 | return Type(); |
353 | } |
354 | |
355 | Type func = Unify(op->func, tcn->func); |
356 | tvm::Array<Type> args; |
357 | for (size_t i = 0; i < op->args.size(); i++) { |
358 | args.push_back(Unify(op->args[i], tcn->args[i])); |
359 | } |
360 | return TypeCall(func, args); |
361 | } |
362 | |
363 | private: |
364 | TypeSolver* solver_; |
365 | Span span; |
366 | }; |
367 | |
368 | class TypeSolver::Resolver : public TypeMutator { |
369 | public: |
370 | explicit Resolver(TypeSolver* solver) : solver_(solver) {} |
371 | |
372 | Type Resolve(const Type& t) { |
373 | if (!t.defined()) { |
374 | return t; |
375 | } |
376 | return VisitType(t); |
377 | } |
378 | |
379 | Type VisitType_(const IncompleteTypeNode* op) override { |
380 | auto* node = solver_->GetTypeNode(GetRef<IncompleteType>(op)); |
381 | return node->resolved_type; |
382 | } |
383 | |
384 | private: |
385 | TypeSolver* solver_; |
386 | }; |
387 | |
388 | // It ends up being more compact to simply have TypeFunctor<void(const Type&) than |
389 | // a TypeVisitor because we can use the default case to dispense with |
390 | // most of the overrides. |
391 | class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> { |
392 | public: |
393 | explicit Propagator(TypeSolver* solver, const std::unordered_set<RelationNode*>* rels) |
394 | : solver_(solver), rels_(rels) {} |
395 | |
396 | // adds the relation node to t and all child types of t |
397 | void Propagate(const Type& t) { VisitType(t); } |
398 | |
399 | void UpdateRelSet(const Type& t) { |
400 | TypeNode* tnode = solver_->GetTypeNode(t); |
401 | for (auto* rel : *rels_) { |
402 | tnode->rel_set.insert(rel); |
403 | } |
404 | } |
405 | |
406 | void VisitTypeDefault_(const Object* op) override { |
407 | ObjectRef nr = GetRef<ObjectRef>(op); |
408 | Type t = GetRef<Type>(nr.as<tvm::relay::TypeNode>()); |
409 | UpdateRelSet(t); |
410 | } |
411 | |
412 | void VisitType_(const TupleTypeNode* op) override { |
413 | TupleType tt = GetRef<TupleType>(op); |
414 | UpdateRelSet(tt); |
415 | |
416 | for (const Type& t : tt->fields) { |
417 | Propagate(t); |
418 | } |
419 | } |
420 | |
421 | void VisitType_(const FuncTypeNode* op) override { |
422 | FuncType ft = GetRef<FuncType>(op); |
423 | UpdateRelSet(ft); |
424 | |
425 | Propagate(ft->ret_type); |
426 | for (auto arg_type : ft->arg_types) { |
427 | Propagate(arg_type); |
428 | } |
429 | |
430 | for (auto type_param : ft->type_params) { |
431 | Propagate(type_param); |
432 | } |
433 | |
434 | for (auto type_cs : ft->type_constraints) { |
435 | Propagate(type_cs); |
436 | } |
437 | } |
438 | |
439 | void VisitType_(const TypeCallNode* op) override { |
440 | TypeCall tc = GetRef<TypeCall>(op); |
441 | UpdateRelSet(tc); |
442 | |
443 | Propagate(tc->func); |
444 | for (auto arg : tc->args) { |
445 | Propagate(arg); |
446 | } |
447 | } |
448 | |
449 | private: |
450 | TypeSolver* solver_; |
451 | const std::unordered_set<RelationNode*>* rels_; |
452 | }; |
453 | |
454 | // similarly, we use TypeFunctor<void(const Type&)> so we can use |
455 | // the default visitor case to avoid more overrides |
456 | class TypeSolver::Merger : public TypeFunctor<void(const Type&)> { |
457 | public: |
458 | explicit Merger(TypeSolver* solver) : solver_(solver) {} |
459 | |
460 | // Merges src node to dst, ensures *all* type relations of all |
461 | // child nodes of src are transferred to dst. |
462 | void Merge(TypeNode* src, TypeNode* dst) { |
463 | if (src == dst) return; |
464 | dst_ = dst; |
465 | VisitType(src->resolved_type); |
466 | // set parent at the end so later calls to GetTypeNode go back to src |
467 | src->parent = dst; |
468 | |
469 | // now propagate relations to child nodes, since change to |
470 | // a child node should update parent too |
471 | Propagator prop(solver_, &dst->rel_set); |
472 | prop.Propagate(dst->resolved_type); |
473 | } |
474 | |
475 | // Transfers any relations linked to t to the stored dst. |
476 | // Any unresolved relations are added back to the queue, since |
477 | // there is now new information |
478 | void TransferLinks(const Type& t) { |
479 | TypeNode* src = solver_->GetTypeNode(t); |
480 | if (src == dst_) return; |
481 | for (auto* rel : src->rel_set) { |
482 | // if the relation is not yet resolved, add to queue |
483 | if (!rel->resolved) { |
484 | solver_->AddToQueue(rel); |
485 | dst_->rel_set.insert(rel); |
486 | } |
487 | } |
488 | } |
489 | |
490 | void VisitTypeDefault_(const Object* op) override { |
491 | ObjectRef nr = GetRef<ObjectRef>(op); |
492 | Type t = GetRef<Type>(nr.as<tvm::relay::TypeNode>()); |
493 | TransferLinks(t); |
494 | } |
495 | |
496 | void VisitType_(const TupleTypeNode* ttn) override { |
497 | auto tup = GetRef<TupleType>(ttn); |
498 | TransferLinks(tup); |
499 | |
500 | for (auto field : tup->fields) { |
501 | VisitType(field); |
502 | } |
503 | } |
504 | |
505 | void VisitType_(const FuncTypeNode* ftn) override { |
506 | auto func = GetRef<FuncType>(ftn); |
507 | TransferLinks(func); |
508 | |
509 | VisitType(func->ret_type); |
510 | for (auto arg : func->arg_types) { |
511 | VisitType(arg); |
512 | } |
513 | for (auto param : func->type_params) { |
514 | VisitType(param); |
515 | } |
516 | for (auto constraint : func->type_constraints) { |
517 | VisitType(constraint); |
518 | } |
519 | } |
520 | |
521 | private: |
522 | TypeSolver* solver_; |
523 | TypeNode* dst_; |
524 | }; |
525 | |
526 | // constructor |
527 | TypeSolver::TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx) |
528 | : reporter_(make_object<Reporter>(this)), |
529 | current_func_(current_func), |
530 | diag_ctx_(diag_ctx), |
531 | module_(diag_ctx->module) { |
532 | ICHECK(module_.defined()); |
533 | } |
534 | |
535 | // destructor |
536 | TypeSolver::~TypeSolver() { |
537 | // call destructor of all non-POD arena object |
538 | for (TypeNode* ptr : type_nodes_) { |
539 | ptr->~TypeNode(); |
540 | } |
541 | for (RelationNode* ptr : rel_nodes_) { |
542 | ptr->~RelationNode(); |
543 | } |
544 | } |
545 | |
546 | // merge src type node to dst |
547 | void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) { |
548 | Merger merger(this); |
549 | merger.Merge(src, dst); |
550 | } |
551 | |
552 | // Add equality constraint |
553 | Type TypeSolver::Unify(const Type& dst, const Type& src, const Span& span, bool assign_lhs, |
554 | bool assign_rhs) { |
555 | Unifier unifier(this, span); |
556 | return unifier.Unify(dst, src, assign_lhs, assign_rhs); |
557 | } |
558 | |
559 | // Add type constraint to the solver. |
560 | void TypeSolver::AddConstraint(const TypeConstraint& constraint, const Span& span) { |
561 | if (const auto* op = constraint.as<TypeRelationNode>()) { |
562 | // create a new relation node. |
563 | RelationNode* rnode = arena_.make<RelationNode>(); |
564 | rnode->span = span; |
565 | rnode->rel = GetRef<TypeRelation>(op); |
566 | rel_nodes_.push_back(rnode); |
567 | // populate the type information. |
568 | for (size_t i = 0; i < op->args.size(); ++i) { |
569 | // insert link to the type list |
570 | LinkNode<TypeNode*>* tlink = arena_.make<LinkNode<TypeNode*>>(); |
571 | TypeNode* tnode = GetTypeNode(op->args[i]); |
572 | tlink->value = tnode; |
573 | rnode->type_list.Push(tlink); |
574 | // insert type->relation node |
575 | std::unordered_set<RelationNode*> singleton{rnode}; |
576 | Propagator prop(this, &singleton); |
577 | prop.Propagate(tnode->resolved_type); |
578 | } |
579 | // add the relation to the working queue. |
580 | this->AddToQueue(rnode); |
581 | } else { |
582 | LOG(FATAL) << "Do not know how to handle constraint type" << constraint->GetTypeKey(); |
583 | } |
584 | } |
585 | |
586 | // Resolve a type in the solver context. |
587 | Type TypeSolver::Resolve(const Type& type) { |
588 | Resolver resolver(this); |
589 | auto it = tmap_.find(type); |
590 | Type t = (it != tmap_.end()) ? it->second->FindRoot()->resolved_type : type; |
591 | return resolver.Resolve(t); |
592 | } |
593 | |
594 | bool TypeSolver::Solve() { |
595 | while (!update_queue_.empty()) { |
596 | RelationNode* rnode = update_queue_.front(); |
597 | const auto& rel = rnode->rel; |
598 | update_queue_.pop(); |
599 | ICHECK(!rnode->resolved); |
600 | // update the relation with given evidence. |
601 | Array<Type> args; |
602 | for (auto* tlink = rnode->type_list.head; tlink != nullptr; tlink = tlink->next) { |
603 | args.push_back(Resolve(tlink->value->FindRoot()->resolved_type)); |
604 | ICHECK_LE(args.size(), rel->args.size()); |
605 | } |
606 | |
607 | // We need to set this in order to understand where unification |
608 | // errors generated by the error reporting are coming from. |
609 | reporter_->SetSpan(rnode->span); |
610 | |
611 | try { |
612 | // Call the Type Relation's function. |
613 | bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_); |
614 | |
615 | if (resolved) { |
616 | ++num_resolved_rels_; |
617 | } |
618 | |
619 | rnode->resolved = resolved; |
620 | } catch (const CompileError& err) { |
621 | this->Emit(Diagnostic::Error(rnode->span) << err.what()); |
622 | rnode->resolved = false; |
623 | } catch (const Error& e) { |
624 | ICHECK(false) << e.what(); |
625 | } |
626 | |
627 | // Mark inqueue as false after the function call |
628 | // so that rnode itself won't get enqueued again. |
629 | rnode->inqueue = false; |
630 | } |
631 | |
632 | // This criterion is not necessarily right for all the possible cases |
633 | // TODO(tqchen): We should also count the number of in-complete types. |
634 | return num_resolved_rels_ == rel_nodes_.size(); |
635 | } |
636 | |
637 | // Expose type solver only for debugging purposes. |
638 | TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver" ) |
639 | .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { |
640 | using runtime::PackedFunc; |
641 | using runtime::TypedPackedFunc; |
642 | auto module = IRModule({}, {}); |
643 | DiagnosticContext diag_ctx = DiagnosticContext::Default(module); |
644 | auto dummy_fn_name = GlobalVar("test" ); |
645 | module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}, {})); |
646 | auto solver = std::make_shared<TypeSolver>(dummy_fn_name, diag_ctx); |
647 | |
648 | auto mod = [module, solver, diag_ctx](std::string name) -> PackedFunc { |
649 | if (name == "Solve" ) { |
650 | return TypedPackedFunc<bool()>([solver]() { return solver->Solve(); }); |
651 | } else if (name == "Unify" ) { |
652 | return TypedPackedFunc<Type(Type, Type)>([module, solver, diag_ctx](Type lhs, Type rhs) { |
653 | auto res = solver->Unify(lhs, rhs, Span()); |
654 | DiagnosticContext ctx = diag_ctx; |
655 | ctx.Render(); |
656 | return res; |
657 | }); |
658 | } else if (name == "Resolve" ) { |
659 | return TypedPackedFunc<Type(Type)>([solver](Type t) { return solver->Resolve(t); }); |
660 | } else if (name == "AddConstraint" ) { |
661 | return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) { |
662 | Expr e = Var("dummy_var" , IncompleteType(Kind::kType), Span(SourceName(), 0, 0, 0, 0)); |
663 | return solver->AddConstraint(c, e->span); |
664 | }); |
665 | } else { |
666 | return PackedFunc(); |
667 | } |
668 | }; |
669 | *ret = runtime::TypedPackedFunc<runtime::PackedFunc(std::string)>(mod); |
670 | }); |
671 | |
672 | } // namespace relay |
673 | } // namespace tvm |
674 | |