1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/int_solver.h |
22 | * \brief integer constraints data structures and solvers |
23 | */ |
24 | #ifndef TVM_ARITH_INT_SOLVER_H_ |
25 | #define TVM_ARITH_INT_SOLVER_H_ |
26 | |
27 | #include <tvm/ir/expr.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | |
31 | #include <unordered_map> |
32 | #include <utility> |
33 | #include <vector> |
34 | |
35 | #include "analyzer.h" |
36 | |
37 | namespace tvm { |
38 | namespace arith { |
39 | |
40 | using tir::IterVar; |
41 | using tir::Var; |
42 | using tir::VarNode; |
43 | |
/*
 * Simplification order used by the solvers: rewrite -> canonical -> rewrite.
 * Experiments found that the two best orders were can->rw and rw->can->rw.
 * rw->can->rw is better for a couple of cases.
 * The sequence deliberately ends with a rewrite pass, because that pass factors
 * multipliers out.
 * The value 3 corresponds to the three passes of rw->can->rw (presumably used as a
 * simplifier step count — NOTE(review): confirm at call sites).
 */
constexpr int kSimplifyRewriteCanonicalRewrite{3};
48 | |
49 | /*! |
50 | * \brief Represent integer grouped bounds which are classified into |
51 | * lower bounds (inclusive), upper bounds (inclusive) and equalities. |
52 | * It also contains coefficient as a multiplier for the bounds, i.e., |
53 | * coef * var >= lower |
54 | * coef * var == equal |
55 | * coef * var <= upper |
56 | * \sa IntGroupBounds |
57 | */ |
58 | class IntGroupBoundsNode : public Object { |
59 | public: |
60 | PrimExpr coef; |
61 | Array<PrimExpr> lower; |
62 | Array<PrimExpr> equal; |
63 | Array<PrimExpr> upper; |
64 | |
65 | void VisitAttrs(tvm::AttrVisitor* v) { |
66 | v->Visit("coef" , &coef); |
67 | v->Visit("lower" , &lower); |
68 | v->Visit("equal" , &equal); |
69 | v->Visit("upper" , &upper); |
70 | } |
71 | |
72 | bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const { |
73 | return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) && |
74 | eq(upper, other->upper); |
75 | } |
76 | |
77 | void SHashReduce(SHashReducer hash_reduce) const { |
78 | hash_reduce(coef); |
79 | hash_reduce(lower); |
80 | hash_reduce(equal); |
81 | hash_reduce(upper); |
82 | } |
83 | |
84 | static constexpr const bool _type_has_method_sequal_reduce = true; |
85 | static constexpr const char* _type_key = "arith.IntGroupBounds" ; |
86 | TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object); |
87 | }; |
88 | |
89 | /*! |
90 | * \brief Managed reference to IntGroupBoundsNode. |
91 | * \sa IntGroupBoundsNode |
92 | */ |
93 | class IntGroupBounds : public ObjectRef { |
94 | public: |
95 | /*! |
96 | * \brief Constructor by fields |
97 | * \param coef The coefficient. Must be integer. |
98 | * coef * var >= lower |
99 | * coef * var == equal |
100 | * coef * var >= upper |
101 | * \param lower the lower bounds (include) |
102 | * \param equal equalities |
103 | * \param upper the upper bounds (include) |
104 | */ |
105 | TVM_DLL IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal, |
106 | Array<PrimExpr> upper); |
107 | |
108 | /*! |
109 | * \brief Construct bounds from a range. |
110 | * \param r The range |
111 | * \return constructed bounds. |
112 | */ |
113 | static IntGroupBounds FromRange(const Range& r); |
114 | |
115 | /*! |
116 | * \brief Perform substitution on all components of the struct. |
117 | */ |
118 | IntGroupBounds Substitute(const Map<Var, PrimExpr>& subst) const; |
119 | |
120 | /*! |
121 | * \brief Find the best range from the grouped bounds. |
122 | * \param vranges_addl additional variable ranges that help infer the best range. |
123 | * \return The best range (has the least difference between the lower bound and upper bound). |
124 | * undefined if (-inf, +inf). |
125 | */ |
126 | Range FindBestRange(const Map<Var, Range>& vranges_addl = {}) const; |
127 | |
128 | /*! |
129 | * \brief Combine the bounds with another range. |
130 | * \param r range to be combined. |
131 | * \return combined bounds. |
132 | */ |
133 | IntGroupBounds operator+(const Range& r); |
134 | |
135 | TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode); |
136 | }; |
137 | |
138 | /*! |
139 | * \brief Represent integer constrains including (integer) variables, their ranges and |
140 | * the relations between them (either equations or inequalities). |
141 | * \sa LinearSystem |
142 | */ |
143 | class IntConstraintsNode : public Object { |
144 | public: |
145 | // e.g., \alpha, \beta, must be integers |
146 | Array<Var> variables; |
147 | // e.g., 1 <= \alpha <= N, etc. |
148 | // it is absolutely ok to include ranges for parameters |
149 | // (variables that are not in this->variables) in this map |
150 | Map<Var, Range> ranges; |
151 | // linear equalities or inequalities |
152 | // e.g., A \alpha = \beta or A \alpha <= \beta |
153 | Array<PrimExpr> relations; |
154 | |
155 | void VisitAttrs(tvm::AttrVisitor* v) { |
156 | v->Visit("variables" , &variables); |
157 | v->Visit("ranges" , &ranges); |
158 | v->Visit("relations" , &relations); |
159 | } |
160 | |
161 | bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const { |
162 | return equal(variables, other->variables) && equal(ranges, other->ranges) && |
163 | equal(relations, other->relations); |
164 | } |
165 | |
166 | void SHashReduce(SHashReducer hash_reduce) const { |
167 | hash_reduce(variables); |
168 | hash_reduce(ranges); |
169 | hash_reduce(relations); |
170 | } |
171 | |
172 | static constexpr const bool _type_has_method_sequal_reduce = true; |
173 | static constexpr const char* _type_key = "arith.IntConstraints" ; |
174 | TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object); |
175 | }; |
176 | |
177 | /*! |
178 | * \brief Managed reference to IntConstraintsNode. |
179 | * \sa IntConstraintsNode |
180 | */ |
181 | class IntConstraints : public ObjectRef { |
182 | public: |
183 | /*! |
184 | * \brief Constructor by fields |
185 | * \param variables The variables in the constraints, must be integers. |
186 | * \param ranges The ranges of the variables. |
187 | * \param relations The linear relations between the variables |
188 | * (either equations or inequalities) |
189 | */ |
190 | TVM_DLL IntConstraints(Array<Var> variables, Map<Var, Range> ranges, Array<PrimExpr> relations); |
191 | |
192 | TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode); |
193 | }; |
194 | |
195 | /*! |
196 | * \brief We can have different set of variables to represent the same constraints. |
197 | * For example, the following two systems are equivalent, |
198 | * {a + b = 0 | a >= 0, b >= 0} and |
199 | * {m - n = 0 | m >= 0, n <= 0} |
200 | * This data structure represents the transformation |
201 | * between two equivalent linear systems. |
202 | * In the above example, |
203 | * src : {a + b = 0 | a >= 0, b >= 0} |
204 | * dst : {m - n = 0 | m >= 0, n <= 0} |
205 | * src_to_dst : {a -> m, b -> -n} |
206 | * dst_to_src : {m -> a, n -> -b} |
207 | * \sa IntConstraintsTransform |
208 | */ |
209 | class IntConstraintsTransformNode : public Object { |
210 | public: |
211 | IntConstraints src; |
212 | IntConstraints dst; |
213 | Map<Var, PrimExpr> src_to_dst; |
214 | Map<Var, PrimExpr> dst_to_src; |
215 | |
216 | void VisitAttrs(tvm::AttrVisitor* v) { |
217 | v->Visit("src" , &src); |
218 | v->Visit("dst" , &dst); |
219 | v->Visit("src_to_dst" , &src_to_dst); |
220 | v->Visit("dst_to_src" , &dst_to_src); |
221 | } |
222 | |
223 | bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const { |
224 | return equal(src, other->src) && equal(dst, other->dst) && |
225 | equal(src_to_dst, other->src_to_dst) && equal(dst_to_src, other->dst_to_src); |
226 | } |
227 | |
228 | void SHashReduce(SHashReducer hash_reduce) const { |
229 | hash_reduce(src); |
230 | hash_reduce(dst); |
231 | hash_reduce(src_to_dst); |
232 | hash_reduce(dst_to_src); |
233 | } |
234 | |
235 | static constexpr const bool _type_has_method_sequal_reduce = true; |
236 | static constexpr const char* _type_key = "arith.IntConstraintsTransform" ; |
237 | TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object); |
238 | }; |
239 | |
240 | /*! |
241 | * \brief Managed reference to IntConstraintsTransformNode. |
242 | * \sa IntConstraintsTransformNode |
243 | */ |
244 | class IntConstraintsTransform : public ObjectRef { |
245 | public: |
246 | /*! |
247 | * \brief Constructor by fields |
248 | * \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0} |
249 | * \param dst integer constraints equivalent to the source, |
250 | * e.g., {m - n = 0 | m >= 0, n <= 0} |
251 | * \param src_to_dst mapping from variables in the \p src to the variables in the \p dst, |
252 | * e.g., {a -> m, b -> -n} |
253 | * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src, |
254 | * e.g., {m -> a, n -> -b} |
255 | */ |
256 | TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst, |
257 | Map<Var, PrimExpr> src_to_dst, Map<Var, PrimExpr> dst_to_src); |
258 | |
259 | /*! |
260 | * \brief Chain-compose two IntConstraintsTransform together. |
261 | * this->dst must be the same as other->src. |
262 | * @param other another IntConstraintsTransform whose src is same as this->dst. |
263 | * @return composed IntConstraintsTransform(this->src, other->dst) |
264 | * with its variables and ranges are properly modified. |
265 | */ |
266 | IntConstraintsTransform operator+(const IntConstraintsTransform& other) const; |
267 | |
268 | TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); |
269 | }; |
270 | |
271 | typedef std::pair<Map<Var, IntGroupBounds>, Array<PrimExpr>> PartialSolvedInequalities; |
272 | |
273 | /*! |
274 | * \brief Obtain Smith Normal Form of linear equation A x = y. |
275 | * Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn}, |
276 | * in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A. |
277 | * NOTE: Although in standard Smith Normal Form the diagonal elements satisfy |
278 | * s_i | s_{i+1} (| means divides), the implement here does not guarantee it. |
279 | * TODO(yzhliu): From sergei-grechanik: |
280 | * computing the proper Smith normal form may improve stability of automatic |
281 | * differentiation (generating the same gradient code for slightly different but equivalent input |
282 | * code U_{mxm} and V_{nxn} are invertible matrices. This function modifies \p S to be S_{mxn}, \p V |
283 | * to be V_{nxn}, \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. \param S the original |
284 | * A_{mxn}, it will be modified to S_{mxn} \param V an identity matrix, it will be modified to |
285 | * V_{nxn} \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} \param y the y |
286 | * in A x = y. it will be modified to U_{mxm} y_{mx1} |
287 | */ |
288 | void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V, |
289 | std::vector<PrimExpr>* x, std::vector<PrimExpr>* y); |
290 | |
291 | /*! |
292 | * \brief Solve linear equations. |
293 | * \param system_to_solve the variables to solve, their ranges, and a list of equations. |
294 | * \return A new linear system, with less variables (if \p system_to_solve is NOT of full rank), |
295 | * or no variable (if \p system_to_solve is of full rank), |
296 | * or an empty linear system (if \p system_to_solve is unsolvable). |
297 | * It also provides the ranges of the variables in the new system, |
298 | * as well as inequalities inferred from the \p system_to_solve. |
299 | * You can get the mapping from the original variables to the solution via ret->src_to_dst. |
300 | */ |
301 | IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve); |
302 | |
303 | /*! |
304 | * \brief Solve linear inequalities. |
305 | * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. |
306 | * The inequalities are rewritten using Fourier-Motzkin elimination. |
307 | * This function takes an array of (in)equalities and an array of variables, and essentially |
308 | * rewrites the (in)equalities into an array of (in)equalities of the following form, |
309 | * |
310 | * x0 >= f0(x1, x2, ..., xn) |
311 | * x0 <= g0(x1, x2, ..., xn) |
312 | * x1 >= f1(x2, ..., xn) |
313 | * x1 <= g1(x2, ..., xn) |
314 | * ... |
315 | * xn >= fn() // just a constant |
316 | * xn <= gn() // just a constant |
317 | * |
318 | * \return A map of variables and their solved bounds, |
319 | * and constrains that cannot be solved to bounds. |
320 | */ |
321 | PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve); |
322 | |
323 | /*! |
324 | * \brief Combine the information into an array of (in)equalities. |
325 | * \param variables The variables in \p bounds. |
326 | * It is used to determine the iteration order to avoid indeterministic results. |
327 | * \param bounds grouped boundary of the variables. |
328 | * \param relations other relations. |
329 | */ |
330 | Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds, |
331 | const Array<PrimExpr>& relations); |
332 | |
333 | /*! |
334 | * \brief Solve linear inequalities and infer the range of each variable. |
335 | * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. |
336 | * \return The result ranges for each variables. |
337 | * The returned IntConstraints(variables, ranges, relations) contains, |
338 | * 1. variables - the variables that have been solved. |
339 | * 2. ranges - the best range of each variable. |
340 | * 3. relations - constraints that cannot be transformed to |
341 | * Range will be stored in relations. |
342 | */ |
343 | IntConstraints SolveInequalitiesToRange(const IntConstraints& system_to_solve); |
344 | |
345 | /*! |
346 | * \brief Solve linear inequalities and deskew the ranges towards zero. |
347 | * \param system_to_solve the variables to solve, their ranges, and a list of inequalities. |
348 | * \return A transform (src IntConstraints -> dst IntConstraints) |
349 | * from original variables to a set of new variables. |
350 | * The ranges of new variables always start from zero, |
351 | * their extents are solved from \p system_to_solve. |
352 | * src IntConstraints is the same as \p system_to_solve. |
353 | * dst IntConstraints(variables, ranges, relations) contains, |
354 | * 1. variables - the variables that have been solved. |
355 | * 2. ranges - the best range (start from zero) of each variable. |
356 | * 3. relations - constraints that cannot be transformed to |
357 | * Range will be stored in relations. |
358 | * Variable mapping can be obtained from |
359 | * IntConstraintsTransform.src_to_dst and IntConstraintsTransform.dst_to_src. |
360 | */ |
361 | IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& system_to_solve); |
362 | |
363 | } // namespace arith |
364 | } // namespace tvm |
365 | #endif // TVM_ARITH_INT_SOLVER_H_ |
366 | |