1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/int_solver.h
22 * \brief integer constraints data structures and solvers
23 */
24#ifndef TVM_ARITH_INT_SOLVER_H_
25#define TVM_ARITH_INT_SOLVER_H_
26
27#include <tvm/ir/expr.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30
31#include <unordered_map>
32#include <utility>
33#include <vector>
34
35#include "analyzer.h"
36
37namespace tvm {
38namespace arith {
39
40using tir::IterVar;
41using tir::Var;
42using tir::VarNode;
43
// Simplification order used by the integer solver.
// Experiments showed that the two best orders are can->rw and rw->can->rw;
// rw->can->rw wins on a couple of cases, and finishing with rw is desirable
// because the rewriter factors multipliers out.
// NOTE(review): the numeric value 3 is presumably a mode id interpreted by the
// implementation -- confirm before changing it.
constexpr int kSimplifyRewriteCanonicalRewrite{3};
48
49/*!
50 * \brief Represent integer grouped bounds which are classified into
51 * lower bounds (inclusive), upper bounds (inclusive) and equalities.
52 * It also contains coefficient as a multiplier for the bounds, i.e.,
53 * coef * var >= lower
54 * coef * var == equal
55 * coef * var <= upper
56 * \sa IntGroupBounds
57 */
58class IntGroupBoundsNode : public Object {
59 public:
60 PrimExpr coef;
61 Array<PrimExpr> lower;
62 Array<PrimExpr> equal;
63 Array<PrimExpr> upper;
64
65 void VisitAttrs(tvm::AttrVisitor* v) {
66 v->Visit("coef", &coef);
67 v->Visit("lower", &lower);
68 v->Visit("equal", &equal);
69 v->Visit("upper", &upper);
70 }
71
72 bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const {
73 return eq(coef, other->coef) && eq(lower, other->lower) && eq(equal, other->equal) &&
74 eq(upper, other->upper);
75 }
76
77 void SHashReduce(SHashReducer hash_reduce) const {
78 hash_reduce(coef);
79 hash_reduce(lower);
80 hash_reduce(equal);
81 hash_reduce(upper);
82 }
83
84 static constexpr const bool _type_has_method_sequal_reduce = true;
85 static constexpr const char* _type_key = "arith.IntGroupBounds";
86 TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object);
87};
88
89/*!
90 * \brief Managed reference to IntGroupBoundsNode.
91 * \sa IntGroupBoundsNode
92 */
93class IntGroupBounds : public ObjectRef {
94 public:
95 /*!
96 * \brief Constructor by fields
97 * \param coef The coefficient. Must be integer.
98 * coef * var >= lower
99 * coef * var == equal
100 * coef * var >= upper
101 * \param lower the lower bounds (include)
102 * \param equal equalities
103 * \param upper the upper bounds (include)
104 */
105 TVM_DLL IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
106 Array<PrimExpr> upper);
107
108 /*!
109 * \brief Construct bounds from a range.
110 * \param r The range
111 * \return constructed bounds.
112 */
113 static IntGroupBounds FromRange(const Range& r);
114
115 /*!
116 * \brief Perform substitution on all components of the struct.
117 */
118 IntGroupBounds Substitute(const Map<Var, PrimExpr>& subst) const;
119
120 /*!
121 * \brief Find the best range from the grouped bounds.
122 * \param vranges_addl additional variable ranges that help infer the best range.
123 * \return The best range (has the least difference between the lower bound and upper bound).
124 * undefined if (-inf, +inf).
125 */
126 Range FindBestRange(const Map<Var, Range>& vranges_addl = {}) const;
127
128 /*!
129 * \brief Combine the bounds with another range.
130 * \param r range to be combined.
131 * \return combined bounds.
132 */
133 IntGroupBounds operator+(const Range& r);
134
135 TVM_DEFINE_OBJECT_REF_METHODS(IntGroupBounds, ObjectRef, IntGroupBoundsNode);
136};
137
138/*!
139 * \brief Represent integer constrains including (integer) variables, their ranges and
140 * the relations between them (either equations or inequalities).
141 * \sa LinearSystem
142 */
143class IntConstraintsNode : public Object {
144 public:
145 // e.g., \alpha, \beta, must be integers
146 Array<Var> variables;
147 // e.g., 1 <= \alpha <= N, etc.
148 // it is absolutely ok to include ranges for parameters
149 // (variables that are not in this->variables) in this map
150 Map<Var, Range> ranges;
151 // linear equalities or inequalities
152 // e.g., A \alpha = \beta or A \alpha <= \beta
153 Array<PrimExpr> relations;
154
155 void VisitAttrs(tvm::AttrVisitor* v) {
156 v->Visit("variables", &variables);
157 v->Visit("ranges", &ranges);
158 v->Visit("relations", &relations);
159 }
160
161 bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const {
162 return equal(variables, other->variables) && equal(ranges, other->ranges) &&
163 equal(relations, other->relations);
164 }
165
166 void SHashReduce(SHashReducer hash_reduce) const {
167 hash_reduce(variables);
168 hash_reduce(ranges);
169 hash_reduce(relations);
170 }
171
172 static constexpr const bool _type_has_method_sequal_reduce = true;
173 static constexpr const char* _type_key = "arith.IntConstraints";
174 TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object);
175};
176
177/*!
178 * \brief Managed reference to IntConstraintsNode.
179 * \sa IntConstraintsNode
180 */
181class IntConstraints : public ObjectRef {
182 public:
183 /*!
184 * \brief Constructor by fields
185 * \param variables The variables in the constraints, must be integers.
186 * \param ranges The ranges of the variables.
187 * \param relations The linear relations between the variables
188 * (either equations or inequalities)
189 */
190 TVM_DLL IntConstraints(Array<Var> variables, Map<Var, Range> ranges, Array<PrimExpr> relations);
191
192 TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode);
193};
194
195/*!
196 * \brief We can have different set of variables to represent the same constraints.
197 * For example, the following two systems are equivalent,
198 * {a + b = 0 | a >= 0, b >= 0} and
199 * {m - n = 0 | m >= 0, n <= 0}
200 * This data structure represents the transformation
201 * between two equivalent linear systems.
202 * In the above example,
203 * src : {a + b = 0 | a >= 0, b >= 0}
204 * dst : {m - n = 0 | m >= 0, n <= 0}
205 * src_to_dst : {a -> m, b -> -n}
206 * dst_to_src : {m -> a, n -> -b}
207 * \sa IntConstraintsTransform
208 */
209class IntConstraintsTransformNode : public Object {
210 public:
211 IntConstraints src;
212 IntConstraints dst;
213 Map<Var, PrimExpr> src_to_dst;
214 Map<Var, PrimExpr> dst_to_src;
215
216 void VisitAttrs(tvm::AttrVisitor* v) {
217 v->Visit("src", &src);
218 v->Visit("dst", &dst);
219 v->Visit("src_to_dst", &src_to_dst);
220 v->Visit("dst_to_src", &dst_to_src);
221 }
222
223 bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const {
224 return equal(src, other->src) && equal(dst, other->dst) &&
225 equal(src_to_dst, other->src_to_dst) && equal(dst_to_src, other->dst_to_src);
226 }
227
228 void SHashReduce(SHashReducer hash_reduce) const {
229 hash_reduce(src);
230 hash_reduce(dst);
231 hash_reduce(src_to_dst);
232 hash_reduce(dst_to_src);
233 }
234
235 static constexpr const bool _type_has_method_sequal_reduce = true;
236 static constexpr const char* _type_key = "arith.IntConstraintsTransform";
237 TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object);
238};
239
240/*!
241 * \brief Managed reference to IntConstraintsTransformNode.
242 * \sa IntConstraintsTransformNode
243 */
244class IntConstraintsTransform : public ObjectRef {
245 public:
246 /*!
247 * \brief Constructor by fields
248 * \param src source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0}
249 * \param dst integer constraints equivalent to the source,
250 * e.g., {m - n = 0 | m >= 0, n <= 0}
251 * \param src_to_dst mapping from variables in the \p src to the variables in the \p dst,
252 * e.g., {a -> m, b -> -n}
253 * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src,
254 * e.g., {m -> a, n -> -b}
255 */
256 TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst,
257 Map<Var, PrimExpr> src_to_dst, Map<Var, PrimExpr> dst_to_src);
258
259 /*!
260 * \brief Chain-compose two IntConstraintsTransform together.
261 * this->dst must be the same as other->src.
262 * @param other another IntConstraintsTransform whose src is same as this->dst.
263 * @return composed IntConstraintsTransform(this->src, other->dst)
264 * with its variables and ranges are properly modified.
265 */
266 IntConstraintsTransform operator+(const IntConstraintsTransform& other) const;
267
268 TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode);
269};
270
271typedef std::pair<Map<Var, IntGroupBounds>, Array<PrimExpr>> PartialSolvedInequalities;
272
273/*!
274 * \brief Obtain Smith Normal Form of linear equation A x = y.
275 * Smith Normal Form of matrix A_{mxn} is S_{mxn} = U_{mxm} A_{mxn} V_{nxn},
276 * in which S_{mxn} is diag(s1, s2, ..., sr, 0, ..., 0) and r is the rank of A.
277 * NOTE: Although in standard Smith Normal Form the diagonal elements satisfy
278 * s_i | s_{i+1} (| means divides), the implement here does not guarantee it.
279 * TODO(yzhliu): From sergei-grechanik:
280 * computing the proper Smith normal form may improve stability of automatic
281 * differentiation (generating the same gradient code for slightly different but equivalent input
282 * code U_{mxm} and V_{nxn} are invertible matrices. This function modifies \p S to be S_{mxn}, \p V
283 * to be V_{nxn}, \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. \param S the original
284 * A_{mxn}, it will be modified to S_{mxn} \param V an identity matrix, it will be modified to
285 * V_{nxn} \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} \param y the y
286 * in A x = y. it will be modified to U_{mxm} y_{mx1}
287 */
288void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V,
289 std::vector<PrimExpr>* x, std::vector<PrimExpr>* y);
290
291/*!
292 * \brief Solve linear equations.
293 * \param system_to_solve the variables to solve, their ranges, and a list of equations.
294 * \return A new linear system, with less variables (if \p system_to_solve is NOT of full rank),
295 * or no variable (if \p system_to_solve is of full rank),
296 * or an empty linear system (if \p system_to_solve is unsolvable).
297 * It also provides the ranges of the variables in the new system,
298 * as well as inequalities inferred from the \p system_to_solve.
299 * You can get the mapping from the original variables to the solution via ret->src_to_dst.
300 */
301IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve);
302
303/*!
304 * \brief Solve linear inequalities.
305 * \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
306 * The inequalities are rewritten using Fourier-Motzkin elimination.
307 * This function takes an array of (in)equalities and an array of variables, and essentially
308 * rewrites the (in)equalities into an array of (in)equalities of the following form,
309 *
310 * x0 >= f0(x1, x2, ..., xn)
311 * x0 <= g0(x1, x2, ..., xn)
312 * x1 >= f1(x2, ..., xn)
313 * x1 <= g1(x2, ..., xn)
314 * ...
315 * xn >= fn() // just a constant
316 * xn <= gn() // just a constant
317 *
318 * \return A map of variables and their solved bounds,
319 * and constrains that cannot be solved to bounds.
320 */
321PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve);
322
323/*!
324 * \brief Combine the information into an array of (in)equalities.
325 * \param variables The variables in \p bounds.
326 * It is used to determine the iteration order to avoid indeterministic results.
327 * \param bounds grouped boundary of the variables.
328 * \param relations other relations.
329 */
330Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
331 const Array<PrimExpr>& relations);
332
333/*!
334 * \brief Solve linear inequalities and infer the range of each variable.
335 * \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
336 * \return The result ranges for each variables.
337 * The returned IntConstraints(variables, ranges, relations) contains,
338 * 1. variables - the variables that have been solved.
339 * 2. ranges - the best range of each variable.
340 * 3. relations - constraints that cannot be transformed to
341 * Range will be stored in relations.
342 */
343IntConstraints SolveInequalitiesToRange(const IntConstraints& system_to_solve);
344
345/*!
346 * \brief Solve linear inequalities and deskew the ranges towards zero.
347 * \param system_to_solve the variables to solve, their ranges, and a list of inequalities.
348 * \return A transform (src IntConstraints -> dst IntConstraints)
349 * from original variables to a set of new variables.
350 * The ranges of new variables always start from zero,
351 * their extents are solved from \p system_to_solve.
352 * src IntConstraints is the same as \p system_to_solve.
353 * dst IntConstraints(variables, ranges, relations) contains,
354 * 1. variables - the variables that have been solved.
355 * 2. ranges - the best range (start from zero) of each variable.
356 * 3. relations - constraints that cannot be transformed to
357 * Range will be stored in relations.
358 * Variable mapping can be obtained from
359 * IntConstraintsTransform.src_to_dst and IntConstraintsTransform.dst_to_src.
360 */
361IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& system_to_solve);
362
363} // namespace arith
364} // namespace tvm
365#endif // TVM_ARITH_INT_SOLVER_H_
366