1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file type_solver.h |
22 | * \brief Solver logic for type inference. |
23 | */ |
24 | #ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ |
25 | #define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ |
26 | |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/error.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/type.h> |
31 | |
32 | #include <queue> |
33 | #include <unordered_map> |
34 | #include <unordered_set> |
35 | #include <vector> |
36 | |
37 | #include "../../support/arena.h" |
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | |
42 | using support::LinkedList; |
43 | using support::LinkNode; |
44 | |
45 | /*! |
46 | * \brief Interface of type solver used in type inference. |
47 | * |
48 | * TypeSolver works on a list of constraints among incomplete types. |
49 | * The user will populate the constraints by AddConstraint and Assign. |
50 | * Then we can call Solve to trying to resolve the unknown. |
51 | * |
52 | * This can be viewed as "type program(computational graph)" of types, where |
53 | * the type constraint are operators of the graph and the incomplete |
54 | * types are intermediate value of the graph. |
55 | * If all the input types are concretely known, we should be able to |
56 | * just run a forward pass on the "type program" to get all the types. |
57 | * |
58 | * The list of constraints representation means we are storing it as a bipartite |
59 | * graph instead of a DAG. This is because some constraints might go both direction. |
60 | * TypeSolver could take advantage of bidirectional constraints to deduce input |
61 | * value given output ones. Never-the-less, we should keep in mind that |
62 | * there is a "forward direction" that the TypeSolver should take advantage of. |
63 | */ |
64 | class TypeSolver { |
65 | public: |
66 | TypeSolver(const GlobalVar& current_func, DiagnosticContext diag_ctx); |
67 | ~TypeSolver(); |
68 | /*! |
69 | * \brief Add a type constraint to the solver. |
70 | * \param constraint The constraint to be added. |
71 | * \param location The location at which the constraint was incurred. |
72 | */ |
73 | void AddConstraint(const TypeConstraint& constraint, const Span& span); |
74 | /*! |
75 | * \brief Resolve type to the solution type in the solver. |
76 | * \param type The type to be resolved. |
77 | * \return The resolved type. |
78 | */ |
79 | Type Resolve(const Type& type); |
80 | /*! |
81 | * \brief Start to solve the types using the current known information. |
82 | * \return Whether all the incomplete types has been fully resolved. |
83 | */ |
84 | bool Solve(); |
85 | /*! |
86 | * \brief Unify lhs and rhs. |
87 | * \param lhs The left operand. |
88 | * \param rhs The right operand |
89 | * \param location The location at which the unification problem arose. |
90 | */ |
91 | Type Unify(const Type& lhs, const Type& rhs, const Span& span, bool assign_lhs = true, |
92 | bool assign_rhs = true); |
93 | /*! |
94 | * \brief Report a diagnostic. |
95 | * \param diag The diagnostic to report. |
96 | */ |
97 | void Emit(const Diagnostic& diag) { diag_ctx_.Emit(diag); } |
98 | |
99 | private: |
100 | class OccursChecker; |
101 | class Unifier; |
102 | class Resolver; |
103 | class Propagator; |
104 | class Merger; |
105 | class Reporter; |
106 | struct TypeNode; |
107 | struct RelationNode; |
108 | // Internally the solver maintains a bipartite graph of Relation and Types. |
109 | // All the object in the structure is managed by a arena allocator |
110 | // which releases the memory upon distruction of the type solver. |
111 | /*! |
112 | * \brief type node struct |
113 | * TypeNode implements a union-find data structure(via parent) |
114 | * that can unifies the same types to the name resolved_type. |
115 | * |
116 | * It also contains collection of links to related Relations, |
117 | * which is stored in rel_set. |
118 | */ |
119 | struct TypeNode { |
120 | /*! \brief The final resolved type */ |
121 | Type resolved_type; |
122 | /*! \brief type node in the union find algorithm */ |
123 | TypeNode* parent{nullptr}; |
124 | /*! \brief set of relations that is related to this type node */ |
125 | std::unordered_set<RelationNode*> rel_set; |
126 | |
127 | /*! |
128 | * \brief Find the root type node, perform path compression |
129 | * \return The root type node. |
130 | */ |
131 | TypeNode* FindRoot() { |
132 | // fast path |
133 | if (this->parent == nullptr) return this; |
134 | // slow path with path compression. |
135 | TypeNode* root = this; |
136 | while (root->parent != nullptr) { |
137 | root = root->parent; |
138 | } |
139 | for (TypeNode* p = this; p != root;) { |
140 | TypeNode* parent = p->parent; |
141 | p->parent = root; |
142 | p = parent; |
143 | } |
144 | return root; |
145 | } |
146 | }; |
147 | |
148 | /*! \brief relation node */ |
149 | struct RelationNode { |
150 | /*! \brief Whether the relation is in the queue to be solved */ |
151 | bool inqueue{false}; |
152 | /*! \brief Whether the relation is resolved */ |
153 | bool resolved{false}; |
154 | /*! \brief The corresponding type relation */ |
155 | TypeRelation rel; |
156 | /*! \brief list types to this relation */ |
157 | LinkedList<TypeNode*> type_list; |
158 | /*! \brief The location this type relation originated from. */ |
159 | Span span; |
160 | }; |
161 | |
162 | /*! \brief A simple union find between shapes. */ |
163 | tvm::Map<IndexExpr, IndexExpr> shape_uf_; |
164 | /*! \brief List of all allocated type nodes */ |
165 | std::vector<TypeNode*> type_nodes_; |
166 | /*! \brief List of all allocated relation nodes */ |
167 | std::vector<RelationNode*> rel_nodes_; |
168 | /*! \brief Number of resolved relations */ |
169 | size_t num_resolved_rels_{0}; |
170 | /*! \brief map from types to type nodes. */ |
171 | std::unordered_map<Type, TypeNode*, ObjectPtrHash, ObjectPtrEqual> tmap_; |
172 | /*! \brief Internal queue to update the relation */ |
173 | std::queue<RelationNode*> update_queue_; |
174 | /*! \brief allocator of all the internal node obhect*/ |
175 | support::Arena arena_; |
176 | /*! \brief Reporter that reports back to self */ |
177 | TypeReporter reporter_; |
178 | /*! \brief The global representing the current function. */ |
179 | GlobalVar current_func_; |
180 | /*! \brief The diagnostic context. */ |
181 | DiagnosticContext diag_ctx_; |
182 | /*! \brief The module. */ |
183 | IRModule module_; |
184 | |
185 | /*! |
186 | * \brief GetTypeNode that is corresponds to t. |
187 | * if it do not exist, create a new one. |
188 | * \return The type node. |
189 | */ |
190 | TypeNode* GetTypeNode(const Type& t) { |
191 | auto it = tmap_.find(t); |
192 | if (it != tmap_.end()) { |
193 | return it->second->FindRoot(); |
194 | } else { |
195 | TypeNode* n = arena_.make<TypeNode>(); |
196 | type_nodes_.push_back(n); |
197 | n->resolved_type = t; |
198 | tmap_[t] = n; |
199 | return n; |
200 | } |
201 | } |
202 | /*! |
203 | * \brief Add relation node rel to the update queue |
204 | * \param rel The relation node |
205 | */ |
206 | void AddToQueue(RelationNode* rel) { |
207 | if (rel->inqueue) return; |
208 | ICHECK(!rel->resolved); |
209 | rel->inqueue = true; |
210 | update_queue_.push(rel); |
211 | } |
212 | |
213 | /*! |
214 | * \brief Merge rhs type node to lhs |
215 | * \param src The source operand |
216 | * \param dst The dst operand. |
217 | */ |
218 | void MergeFromTo(TypeNode* src, TypeNode* dst); |
219 | }; |
220 | |
221 | } // namespace relay |
222 | } // namespace tvm |
223 | #endif // TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ |
224 | |