1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/solve_linear_equation.cc
22 * \brief Solve linear equations.
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/arith/int_solver.h>
26#include <tvm/arith/pattern.h>
27#include <tvm/runtime/data_type.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/op.h>
31#include <tvm/tir/stmt_functor.h>
32
33#include "int_operator.h"
34
35namespace tvm {
36namespace arith {
37
38using namespace tvm::runtime;
39
/*!
 * \brief Diagonalize the integer matrix S in place with row and column operations,
 *        mirroring every operation on V, x and y.
 *
 * With m = S->size() and n = (*S)[0].size(), the function walks the diagonal positions
 * index = 0 .. min(m, n) - 1. For each position it:
 *  1. picks a pivot row (nonzero entry of minimal absolute value in the index-th column)
 *     and swaps it into place, falling back to a column swap if the column is all zero;
 *  2. zeroes the index-th column below the pivot with integer row operations, applying the
 *     same operations to y;
 *  3. zeroes the index-th row to the right of the pivot with integer column operations,
 *     applying the same operations to the columns of V and the reverse transformation to x;
 *  4. repeats the same index if step 3 may have put nonzeros back into the index-th column.
 *
 * \param S The m x n coefficient matrix; modified in place.
 * \param V An n x n matrix (size is checked with ICHECK_EQ); receives the same column
 *          operations as S.
 * \param x One expression per column of S; swapped/combined whenever columns are.
 * \param y One expression per row of S; receives the same row operations as S.
 *
 * \note Returns immediately, leaving everything untouched, if S or V is empty.
 */
void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V,
                         std::vector<PrimExpr>* x, std::vector<PrimExpr>* y) {
  if (S->empty() || V->empty()) return;
  size_t m = S->size();
  size_t n = (*S)[0].size();  // n is # of variables
  ICHECK_EQ(V->size(), n);
  ICHECK_EQ((*V)[0].size(), n);

  for (size_t index = 0; index < std::min(m, n); ++index) {
    // Here S is partially diagonalized, that is S[i, j] is zero for all i, j
    // such that (i < index) or (j < index), unless (i == j).
    // That is, now we are diagonalizing the submatrix with i >= index and j >= index

    // Find a row with a nonzero element in the index-th column
    // (We also prefer rows where this element has minimal abs value)
    size_t best_i = index;
    for (size_t i = best_i; i < m; ++i) {
      int64_t s_old = (*S)[best_i][index];
      int64_t s_new = (*S)[i][index];
      if (s_new != 0) {
        if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) {
          best_i = i;
        }
      }
    }
    // Move the row we found to the index-th position
    // (the right hand side y is permuted together with the rows of S)
    std::swap((*S)[index], (*S)[best_i]);
    std::swap((*y)[index], (*y)[best_i]);

    // If the index-th diagonal element is still zero, try to find a column with nonzero index-th
    // element and move it to the index-th position
    if ((*S)[index][index] == 0) {
      for (size_t j = index + 1; j < n; ++j) {
        if ((*S)[index][j] != 0) {
          // Only rows >= index are swapped in S; V is swapped over all of its rows.
          for (size_t i = index; i < m; ++i) {
            std::swap((*S)[i][index], (*S)[i][j]);
          }
          // swapping columns corresponds to swapping the corresponding x
          std::swap((*x)[index], (*x)[j]);
          for (size_t i = 0; i < n; ++i) {
            std::swap((*V)[i][index], (*V)[i][j]);
          }
          break;
        }
      }
    }

    // If the index-th diagonal element is still zero, then both the index-th row and the index-th
    // column are completely zero, and we don't need to do anything; just go to the next index
    if ((*S)[index][index] == 0) {
      continue;
    }

    // Now the index-th diagonal element is non-zero and we can zero all the index-th column
    // below it by subtracting rows from each other
    for (auto i = index + 1; i < m; ++i) {
      if ((*S)[i][index] != 0) {
        int64_t g, a, b;
        // g = a*matrix[index][index] + b*matrix[i][index]
        if ((*S)[i][index] % (*S)[index][index] != 0) {
          g = ExtendedEuclidean((*S)[index][index], (*S)[i][index], &a, &b);
        } else {
          // Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
          g = (*S)[index][index];
          a = 1;
          b = 0;
        }

        // Let m = S[index][index], n = S[i][index], then the following is true:
        //
        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
        //
        // Note that the two matrices are integer (since g = gcd(m, n)).
        // We will essentially multiply our matrix on the left by a dilated and transposed version
        // of the first of these two matrices. The second matrix is not needed here, however we will
        // use it while zeroing the index-th row.

        // Note: these locals hold pivot/g and entry/g; they are unrelated to the matrix
        // dimensions m and n declared at the top of the function.
        int64_t m_g = (*S)[index][index] / g;
        int64_t n_g = (*S)[i][index] / g;

        // Note that j is the index of the column, not the row
        for (size_t j = index; j < (*S)[i].size(); ++j) {
          // Multiply index-th row by a and add the i-th row multiplied by b
          // This will make the index-th diagonal element equal to the gcd
          int64_t new_index_j = a * (*S)[index][j] + b * (*S)[i][j];
          // This transformation performs zeroing of matrix[i][index]
          int64_t new_i_j = n_g * (*S)[index][j] - m_g * (*S)[i][j];
          (*S)[index][j] = new_index_j;
          (*S)[i][j] = new_i_j;
        }
        // We have to do the same with rhs
        // (each constant is created with the dtype of the expression it multiplies)
        PrimExpr ea = tir::make_const((*y)[index].dtype(), a);
        PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
        PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
        PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
        PrimExpr new_index_rhs = ea * (*y)[index] + eb * (*y)[i];
        PrimExpr new_i_rhs = e_n_g * (*y)[index] - e_m_g * (*y)[i];
        (*y)[index] = new_index_rhs;
        (*y)[i] = new_i_rhs;
      }
    }

    // Set when a column operation below may have reintroduced nonzeros into the index-th column.
    bool changed = false;

    // Now we have to zero the elements of the index-th row by manipulating columns.
    // This is more difficult because column manipulation corresponds to variable manipulation,
    // but the algorithm is essentially the same as before.
    for (size_t j = index + 1; j < n; ++j) {
      if ((*S)[index][j] != 0) {
        int64_t g, a, b;
        // g = a*matrix[index][index] + b*matrix[index][j]
        if ((*S)[index][j] % (*S)[index][index] != 0) {
          g = ExtendedEuclidean((*S)[index][index], (*S)[index][j], &a, &b);
          // During this phase we may disrupt the zeroness of the index-th column, so we will
          // have to take some action if this might have happened.
          changed = true;
        } else {
          // Explicitly avoid changing the index-th column. This is important to avoid infinite
          // loop. Note that here we don't have to set `changed` to true since we don't change the
          // index-th column.
          g = (*S)[index][index];
          a = 1;
          b = 0;
        }

        // Let m = S[index][index], n = S[index][j], then the following is true:
        //
        // [ a   n/g ][ m/g  n/g ] = [ 1  0 ]
        // [ b  -m/g ][ b    -a  ] = [ 0  1 ]
        //
        // Now we are going to multiply our matrix on the right (to manipulate columns instead of
        // rows), we will also transform the old_to_new matrix (V) the same way, and we will use
        // the second matrix to transform new_to_old (x).

        int64_t m_g = (*S)[index][index] / g;
        int64_t n_g = (*S)[index][j] / g;

        for (size_t i = index; i < m; ++i) {
          int64_t new_i_index = a * (*S)[i][index] + b * (*S)[i][j];
          int64_t new_i_j = n_g * (*S)[i][index] - m_g * (*S)[i][j];
          (*S)[i][index] = new_i_index;
          (*S)[i][j] = new_i_j;
        }
        // We do exactly the same transformations with V
        for (size_t i = 0; i < n; ++i) {
          int64_t new_i_index = a * (*V)[i][index] + b * (*V)[i][j];
          int64_t new_i_j = n_g * (*V)[i][index] - m_g * (*V)[i][j];
          (*V)[i][index] = new_i_index;
          (*V)[i][j] = new_i_j;
        }
        // And apply reverse transformations to new_to_old.
        PrimExpr ea = tir::make_const((*x)[j].dtype(), a);
        PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
        PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
        PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
        PrimExpr new_index = e_m_g * (*x)[index] + e_n_g * (*x)[j];
        PrimExpr new_j = eb * (*x)[index] - ea * (*x)[j];
        (*x)[index] = new_index;
        (*x)[j] = new_j;
      }
    }

    if (changed) {
      // We might have changed the index-th column, so we have to zero it once more
      // (or at least check if it's zero), so just perform this iteration once more.
      // `index` is a size_t: when it is 0 the decrement wraps around, and the loop's ++index
      // wraps it back to 0 before the loop condition is evaluated again.
      index -= 1;
    }
  }
}
210
211Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer, const Array<Var>& ori_vars,
212 const Map<Var, Range>& ori_ranges) {
213 // The resulting ranges
214 Map<Var, Range> new_ranges;
215
216 std::unordered_set<const VarNode*> ori_vset;
217 for (const Var& v : ori_vars) {
218 ori_vset.insert(v.get());
219 }
220
221 std::unordered_map<const VarNode*, IntSet> var_intsets;
222 for (const auto& p : ori_ranges) {
223 if (!ori_vset.count(p.first.get())) {
224 // First of all, fill the new ranges with outer variable ranges
225 new_ranges.Set(p.first, p.second);
226 }
227 // Convert original ranges to IntSets
228 var_intsets[p.first.get()] = IntSet::FromRange(p.second);
229 }
230
231 // Infer ranges for the new variables and add them to the resulting ranges
232 for (const auto& p : vars_to_infer) {
233 const auto& var = p.first;
234 const auto& expr = p.second;
235 Range range = EvalSet(expr, var_intsets).CoverRange(Range());
236 if (range.defined()) {
237 new_ranges.Set(var, range);
238 }
239 }
240 return new_ranges;
241}
242
243// pretty print matrix equation
244void DebugPrint(const std::vector<std::vector<int64_t>>& S,
245 const std::vector<std::vector<int64_t>>& V, const std::vector<PrimExpr>& V_inv_x,
246 const std::vector<PrimExpr>& rhs) {
247 std::cout << "S:\n";
248 for (size_t i = 0; i < S.size(); ++i) {
249 for (auto e : S[i]) {
250 std::cout << e << "\t";
251 }
252 std::cout << "\t->\t" << rhs[i];
253 std::cout << "\n";
254 }
255 std::cout << "V:\n";
256 for (const auto& r : V) {
257 for (auto e : r) {
258 std::cout << e << "\t";
259 }
260 std::cout << "\n";
261 }
262 std::cout << "V_inv x:\n" << Array<PrimExpr>(V_inv_x);
263 std::cout << "\n" << std::endl;
264}
265
266IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve) {
267 // m: # of equations
268 // n: # of variables
269 // we first construct A_{mxn} x_{nx1} = y_{mx1}
270 // then get Smith normal form of matrix A,
271 // S_{mxn} = U_{mxm} A_{mxn} V_{nxn}
272 // => U^{-1} S V^{-1} x = y
273 // S V^{-1} x = U y
274 std::vector<PrimExpr> Uy; // mx1
275 std::vector<std::vector<int64_t>> S; // mxn
276 std::vector<std::vector<int64_t>> V; // nxn
277 std::vector<PrimExpr> V_inv_x; // V^{-1} x, nx1
278 // Conditions we don't know what to do with
279 std::vector<PrimExpr> rest;
280
281 Analyzer analyzer_problem;
282 analyzer_problem.Bind(system_to_solve->ranges);
283
284 size_t num_vars = system_to_solve->variables.size();
285
286 // initialize V_{nxn} with identity matrix,
287 // initialize V^{-1} x as x
288 for (size_t i = 0; i < num_vars; ++i) {
289 V.emplace_back(num_vars);
290 V.back()[i] = 1;
291 V_inv_x.push_back(system_to_solve->variables[i]);
292 }
293
294 // Transform formulas into rows of the matrix
295 // S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables
296 // here we initialize S_{mxn} to be A, U to be identity matrix.
297 for (const PrimExpr& equation : system_to_solve->relations) {
298 if (const tir::EQNode* eq = equation.as<tir::EQNode>()) {
299 // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n]
300 Array<PrimExpr> coeffs = arith::DetectLinearEquation(analyzer_problem.Simplify(eq->a - eq->b),
301 system_to_solve->variables);
302 if (!coeffs.empty()) {
303 std::vector<int64_t> row;
304 for (size_t j = 0; j < coeffs.size() - 1; ++j) {
305 PrimExpr c = coeffs[j];
306 if (const IntImmNode* ic = c.as<IntImmNode>()) {
307 row.push_back(ic->value);
308 } else {
309 // elements in matrix S V must be integers
310 // ignore equations that we cannot deal with.
311 LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation "
312 << equation;
313 row.clear();
314 break;
315 }
316 }
317
318 if (!row.empty()) {
319 // S V^{-1} (a-b) = Uy
320 // V is identity for now
321 S.push_back(row);
322 Uy.push_back(-coeffs[coeffs.size() - 1]);
323 continue;
324 }
325 }
326 }
327
328 // otherwise
329 rest.push_back(equation);
330 }
331
332 // After diagonalizing, we have
333 // S_{mxn} is the Smith normal form (diagonal matrix)
334 // V_{nxn} is invertible
335 // V_inv_x is V^{-1} \times x
336 // Uy is U \times y
337 SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy);
338
339 Array<Var> new_vars;
340 Array<PrimExpr> new_relations;
341 Map<Var, PrimExpr> new_to_old_map;
342 Map<Var, PrimExpr> old_to_new_map;
343
344 // Simplify right hand sides
345 for (PrimExpr r : Uy) {
346 r = analyzer_problem.Simplify(r);
347 }
348
349 // Create the relations of the existence of a solution
350 for (size_t j = 0; j < S.size(); ++j) {
351 PrimExpr new_relation;
352 if (j >= num_vars || S[j][j] == 0) {
353 // The row of matrix is zero. A solution exists only if the Ub[j] is also zero
354 new_relation = (Uy[j] == 0);
355 } else {
356 // The diagonal element is non-zero. A solution exists only if the diagonal element
357 // is a divisor of the Ub[j]
358 new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0);
359 }
360 new_relation = analyzer_problem.Simplify(new_relation);
361 if (tir::is_const_int(new_relation, 0)) {
362 // unable to solve the system.
363 return IntConstraintsTransform(system_to_solve,
364 IntConstraints(
365 /*variables=*/{},
366 /*ranges=*/{},
367 /*relations=*/{tir::make_zero(DataType::Bool())}),
368 {}, {});
369 } else if (!tir::is_const_int(new_relation, 1)) {
370 new_relations.push_back(new_relation);
371 }
372 }
373
374 Array<PrimExpr> solution_for_V_inv_x;
375 // Now create new variables or directly solve the equations
376 // suppose the rank of A is r, aka r = # of non-zeros in S
377 // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b
378 // is
379 // x = (pseudo-inverse of A) b + K_{(n)x(n-r)} z_{n-r}
380 // = V_{nxn} S^{-1}_{nxm} (Ub)_{mxn} + K_{(n)x(n-r)} z_{n-r}
381 // in which K is the right n-r columns of V, z is variable vector
382 // thus,
383 // V^{-1} x = S^{-1}_{nxm} (Ub)_{mxn} +
384 // [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r}
385 for (size_t j = 0; j < num_vars; ++j) {
386 if (j >= S.size() || S[j][j] == 0) {
387 // The j-th variable can take any integer value, create a tvm variable for it
388 PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]);
389 std::string name_hint = "n" + std::to_string(new_vars.size());
390 if (const VarNode* v_old = to_old.as<VarNode>()) {
391 name_hint += "_" + v_old->name_hint;
392 }
393 Var v = Var(name_hint, V_inv_x[j].dtype());
394 solution_for_V_inv_x.push_back(v);
395 new_vars.push_back(v);
396 new_to_old_map.Set(v, to_old);
397 } else {
398 // The j-th variable is just a single value, don't create a tvm variable
399 // S^{-1}_{nxm} Uy_{mxn}
400 if (S[j][j] >= 0) {
401 PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
402 solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(Uy[j], a)));
403 } else {
404 // This is required because some simplifiers
405 // have problems with dividing by negative numbers
406 PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
407 solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(-Uy[j], a)));
408 }
409 }
410 }
411
412 // V V^{-1} x = x
413 for (size_t i = 0; i < num_vars; ++i) {
414 PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
415 for (size_t j = 0; j < num_vars; ++j) {
416 e = e + tir::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j];
417 }
418 e = analyzer_problem.Simplify(e);
419 old_to_new_map.Set(system_to_solve->variables[i], e);
420 }
421
422 // The resulting ranges
423 Map<Var, Range> new_ranges =
424 InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges);
425 Analyzer analyzer_solution;
426 analyzer_solution.Bind(new_ranges);
427
428 // We have to transform ranges of the old variables into relations over new variables because
429 // new ranges are not enough usually.
430 for (const auto& old_var : system_to_solve->variables) {
431 if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) {
432 const Range& old_range = system_to_solve->ranges.at(old_var);
433 PrimExpr express_by_new_vars = old_to_new_map.at(old_var);
434 PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars);
435 PrimExpr upper_cond =
436 analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent);
437 if (!tir::is_const_int(lower_cond, 1)) {
438 new_relations.push_back(lower_cond);
439 }
440 if (!tir::is_const_int(upper_cond, 1)) {
441 new_relations.push_back(upper_cond);
442 }
443 }
444 }
445
446 // Add the rest conditions
447 for (const PrimExpr& cond : rest) {
448 new_relations.push_back(Substitute(cond, old_to_new_map));
449 }
450
451 IntConstraints solution(new_vars, new_ranges, new_relations);
452 IntConstraintsTransform transform(system_to_solve, solution, old_to_new_map, new_to_old_map);
453
454 return transform;
455}
456
457TVM_REGISTER_GLOBAL("arith.SolveLinearEquations").set_body([](TVMArgs args, TVMRetValue* ret) {
458 if (args.size() == 1) {
459 *ret = SolveLinearEquations(args[0]);
460 } else if (args.size() == 3) {
461 IntConstraints problem(args[0], args[1], args[2]);
462 *ret = SolveLinearEquations(problem);
463 } else {
464 LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size();
465 }
466});
467
468} // namespace arith
469} // namespace tvm
470