1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file type_solver.h
22 * \brief Solver logic for type inference.
23 */
24#ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
25#define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
26
27#include <tvm/relay/analysis.h>
28#include <tvm/relay/error.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/type.h>
31
32#include <queue>
33#include <unordered_map>
34#include <unordered_set>
35#include <vector>
36
37#include "../../support/arena.h"
38
39namespace tvm {
40namespace relay {
41
42using support::LinkedList;
43using support::LinkNode;
44
/*!
 * \brief Interface of type solver used in type inference.
 *
 * TypeSolver works on a list of constraints among incomplete types.
 * The user populates the constraints via AddConstraint and Unify,
 * and then calls Solve to try to resolve the unknowns.
 *
 * This can be viewed as a "type program" (a computational graph of types), where
 * the type constraints are the operators of the graph and the incomplete
 * types are the intermediate values of the graph.
 * If all the input types are concretely known, we should be able to
 * just run a forward pass on the "type program" to get all the types.
 *
 * The list-of-constraints representation means we are storing it as a bipartite
 * graph instead of a DAG. This is because some constraints might go in both directions.
 * TypeSolver could take advantage of bidirectional constraints to deduce input
 * values given output ones. Nevertheless, we should keep in mind that
 * there is a "forward direction" that the TypeSolver should take advantage of.
 */
64class TypeSolver {
65 public:
66 TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx);
67 ~TypeSolver();
68 /*!
69 * \brief Add a type constraint to the solver.
70 * \param constraint The constraint to be added.
71 * \param location The location at which the constraint was incurred.
72 */
73 void AddConstraint(const TypeConstraint& constraint, const Span& span);
74 /*!
75 * \brief Resolve type to the solution type in the solver.
76 * \param type The type to be resolved.
77 * \return The resolved type.
78 */
79 Type Resolve(const Type& type);
80 /*!
81 * \brief Start to solve the types using the current known information.
82 * \return Whether all the incomplete types has been fully resolved.
83 */
84 bool Solve();
85 /*!
86 * \brief Unify lhs and rhs.
87 * \param lhs The left operand.
88 * \param rhs The right operand
89 * \param location The location at which the unification problem arose.
90 */
91 Type Unify(const Type& lhs, const Type& rhs, const Span& span, bool assign_lhs = true,
92 bool assign_rhs = true);
93 /*!
94 * \brief Report a diagnostic.
95 * \param diag The diagnostic to report.
96 */
97 void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); }
98
99 private:
100 class OccursChecker;
101 class Unifier;
102 class Resolver;
103 class Propagator;
104 class Merger;
105 class Reporter;
106 struct TypeNode;
107 struct RelationNode;
108 // Internally the solver maintains a bipartite graph of Relation and Types.
109 // All the object in the structure is managed by a arena allocator
110 // which releases the memory upon distruction of the type solver.
111 /*!
112 * \brief type node struct
113 * TypeNode implements a union-find data structure(via parent)
114 * that can unifies the same types to the name resolved_type.
115 *
116 * It also contains collection of links to related Relations,
117 * which is stored in rel_set.
118 */
119 struct TypeNode {
120 /*! \brief The final resolved type */
121 Type resolved_type;
122 /*! \brief type node in the union find algorithm */
123 TypeNode* parent{nullptr};
124 /*! \brief set of relations that is related to this type node */
125 std::unordered_set<RelationNode*> rel_set;
126
127 /*!
128 * \brief Find the root type node, perform path compression
129 * \return The root type node.
130 */
131 TypeNode* FindRoot() {
132 // fast path
133 if (this->parent == nullptr) return this;
134 // slow path with path compression.
135 TypeNode* root = this;
136 while (root->parent != nullptr) {
137 root = root->parent;
138 }
139 for (TypeNode* p = this; p != root;) {
140 TypeNode* parent = p->parent;
141 p->parent = root;
142 p = parent;
143 }
144 return root;
145 }
146 };
147
148 /*! \brief relation node */
149 struct RelationNode {
150 /*! \brief Whether the relation is in the queue to be solved */
151 bool inqueue{false};
152 /*! \brief Whether the relation is resolved */
153 bool resolved{false};
154 /*! \brief The corresponding type relation */
155 TypeRelation rel;
156 /*! \brief list types to this relation */
157 LinkedList<TypeNode*> type_list;
158 /*! \brief The location this type relation originated from. */
159 Span span;
160 };
161
162 /*! \brief A simple union find between shapes. */
163 tvm::Map<IndexExpr, IndexExpr> shape_uf_;
164 /*! \brief List of all allocated type nodes */
165 std::vector<TypeNode*> type_nodes_;
166 /*! \brief List of all allocated relation nodes */
167 std::vector<RelationNode*> rel_nodes_;
168 /*! \brief Number of resolved relations */
169 size_t num_resolved_rels_{0};
170 /*! \brief map from types to type nodes. */
171 std::unordered_map<Type, TypeNode*, ObjectPtrHash, ObjectPtrEqual> tmap_;
172 /*! \brief Internal queue to update the relation */
173 std::queue<RelationNode*> update_queue_;
174 /*! \brief allocator of all the internal node obhect*/
175 support::Arena arena_;
176 /*! \brief Reporter that reports back to self */
177 TypeReporter reporter_;
178 /*! \brief The global representing the current function. */
179 GlobalVar current_func_;
180 /*! \brief The diagnostic context. */
181 DiagnosticContext diag_ctx_;
182 /*! \brief The module. */
183 IRModule module_;
184
185 /*!
186 * \brief GetTypeNode that is corresponds to t.
187 * if it do not exist, create a new one.
188 * \return The type node.
189 */
190 TypeNode* GetTypeNode(const Type& t) {
191 auto it = tmap_.find(t);
192 if (it != tmap_.end()) {
193 return it->second->FindRoot();
194 } else {
195 TypeNode* n = arena_.make<TypeNode>();
196 type_nodes_.push_back(n);
197 n->resolved_type = t;
198 tmap_[t] = n;
199 return n;
200 }
201 }
202 /*!
203 * \brief Add relation node rel to the update queue
204 * \param rel The relation node
205 */
206 void AddToQueue(RelationNode* rel) {
207 if (rel->inqueue) return;
208 ICHECK(!rel->resolved);
209 rel->inqueue = true;
210 update_queue_.push(rel);
211 }
212
213 /*!
214 * \brief Merge rhs type node to lhs
215 * \param src The source operand
216 * \param dst The dst operand.
217 */
218 void MergeFromTo(TypeNode* src, TypeNode* dst);
219};
220
221} // namespace relay
222} // namespace tvm
223#endif // TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
224