1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/solve_linear_equation.cc |
22 | * \brief Solve linear equations. |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/arith/int_solver.h> |
26 | #include <tvm/arith/pattern.h> |
27 | #include <tvm/runtime/data_type.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/op.h> |
31 | #include <tvm/tir/stmt_functor.h> |
32 | |
33 | #include "int_operator.h" |
34 | |
35 | namespace tvm { |
36 | namespace arith { |
37 | |
38 | using namespace tvm::runtime; |
39 | |
40 | void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::vector<int64_t>>* V, |
41 | std::vector<PrimExpr>* x, std::vector<PrimExpr>* y) { |
42 | if (S->empty() || V->empty()) return; |
43 | size_t m = S->size(); |
44 | size_t n = (*S)[0].size(); // n is # of variables |
45 | ICHECK_EQ(V->size(), n); |
46 | ICHECK_EQ((*V)[0].size(), n); |
47 | |
48 | for (size_t index = 0; index < std::min(m, n); ++index) { |
49 | // Here A is partially diagonalized, that is A[i, j] is zero for all i, j |
50 | // such that (i < index) or (j < index), unless (i == j). |
51 | // That is, now we are diagonalizing the submatrix with i >= index and j >= index |
52 | |
53 | // Find a row with a nonzero element in the index-th column |
54 | // (We also prefer rows where this element has minimal abs value) |
55 | size_t best_i = index; |
56 | for (size_t i = best_i; i < m; ++i) { |
57 | int64_t s_old = (*S)[best_i][index]; |
58 | int64_t s_new = (*S)[i][index]; |
59 | if (s_new != 0) { |
60 | if (s_old == 0 || std::abs(s_new) < std::abs(s_old)) { |
61 | best_i = i; |
62 | } |
63 | } |
64 | } |
65 | // Move the row we found to the index-th position |
66 | std::swap((*S)[index], (*S)[best_i]); |
67 | std::swap((*y)[index], (*y)[best_i]); |
68 | |
69 | // If the index-th diagonal element is still zero, try to find a column with nonzero index-th |
70 | // element and move it to the index-th position |
71 | if ((*S)[index][index] == 0) { |
72 | for (size_t j = index + 1; j < n; ++j) { |
73 | if ((*S)[index][j] != 0) { |
74 | for (size_t i = index; i < m; ++i) { |
75 | std::swap((*S)[i][index], (*S)[i][j]); |
76 | } |
77 | // swapping columns corresponds to swapping the corresponding x |
78 | std::swap((*x)[index], (*x)[j]); |
79 | for (size_t i = 0; i < n; ++i) { |
80 | std::swap((*V)[i][index], (*V)[i][j]); |
81 | } |
82 | break; |
83 | } |
84 | } |
85 | } |
86 | |
87 | // If the index-th diagonal element is still zero, then both the index-th row and the index-th |
88 | // column are completely zero, and we don't need to do anything; just go to the next index |
89 | if ((*S)[index][index] == 0) { |
90 | continue; |
91 | } |
92 | |
93 | // Now the index-th diagonal element is non-zero and we can zero all the index-th column |
94 | // below it by subtracting rows from each other |
95 | for (auto i = index + 1; i < m; ++i) { |
96 | if ((*S)[i][index] != 0) { |
97 | int64_t g, a, b; |
98 | // g = a*matrix[index][index] + b*matrix[i][index] |
99 | if ((*S)[i][index] % (*S)[index][index] != 0) { |
100 | g = ExtendedEuclidean((*S)[index][index], (*S)[i][index], &a, &b); |
101 | } else { |
102 | // Explicitly avoid changing the index-th row. This is important to avoid infinite loop. |
103 | g = (*S)[index][index]; |
104 | a = 1; |
105 | b = 0; |
106 | } |
107 | |
108 | // Let m = S[index][index], n = S[i][index], then the following is true: |
109 | // |
110 | // [ a n/g ][ m/g n/g ] = [ 1 0 ] |
111 | // [ b -m/g ][ b -a ] = [ 0 1 ] |
112 | // |
113 | // Note that the two matrices are integer (since g = gcd(m, n)). |
114 | // We will essentially multiply our matrix on the left by a dilated and transposed version |
115 | // of the first of these two matrices. The second matrix is not needed here, however we will |
116 | // use it while zeroing the index-th row. |
117 | |
118 | int64_t m_g = (*S)[index][index] / g; |
119 | int64_t n_g = (*S)[i][index] / g; |
120 | |
121 | // Note that j is the index of the column, not the row |
122 | for (size_t j = index; j < (*S)[i].size(); ++j) { |
123 | // Multiply index-th row by a and add the i-th row multiplied by b |
124 | // This will make the index-th diagonal element equal to the gcd |
125 | int64_t new_index_j = a * (*S)[index][j] + b * (*S)[i][j]; |
126 | // This transformation performs zeroing of matrix[i][index] |
127 | int64_t new_i_j = n_g * (*S)[index][j] - m_g * (*S)[i][j]; |
128 | (*S)[index][j] = new_index_j; |
129 | (*S)[i][j] = new_i_j; |
130 | } |
131 | // We have to do the same with rhs |
132 | PrimExpr ea = tir::make_const((*y)[index].dtype(), a); |
133 | PrimExpr eb = tir::make_const((*y)[i].dtype(), b); |
134 | PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g); |
135 | PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g); |
136 | PrimExpr new_index_rhs = ea * (*y)[index] + eb * (*y)[i]; |
137 | PrimExpr new_i_rhs = e_n_g * (*y)[index] - e_m_g * (*y)[i]; |
138 | (*y)[index] = new_index_rhs; |
139 | (*y)[i] = new_i_rhs; |
140 | } |
141 | } |
142 | |
143 | bool changed = false; |
144 | |
145 | // Now we have to zero the elements of the index-th row by manipulating columns. |
146 | // This is more difficult because column manipulation corresponds to variable manipulation, |
147 | // but the algorithm is essentially the same as before. |
148 | for (size_t j = index + 1; j < n; ++j) { |
149 | if ((*S)[index][j] != 0) { |
150 | int64_t g, a, b; |
151 | // g = a*matrix[index][index] + b*matrix[index][j] |
152 | if ((*S)[index][j] % (*S)[index][index] != 0) { |
153 | g = ExtendedEuclidean((*S)[index][index], (*S)[index][j], &a, &b); |
154 | // During this phase we may disrupt the zeroness of the index-th column, so we will |
155 | // have to take some action if this might have happened. |
156 | changed = true; |
157 | } else { |
158 | // Explicitly avoid changing the index-th column. This is important to avoid infinite |
159 | // loop. Note that here we don't have to set `changed` to true since we don't change the |
160 | // index-th column. |
161 | g = (*S)[index][index]; |
162 | a = 1; |
163 | b = 0; |
164 | } |
165 | |
166 | // Let m = S[index][index], n = S[index][j], then the following is true: |
167 | // |
168 | // [ a n/g ][ m/g n/g ] = [ 1 0 ] |
169 | // [ b -m/g ][ b -a ] = [ 0 1 ] |
170 | // |
171 | // Now we are going to multiply our matrix on the right (to manipulate columns instead of |
172 | // rows), we will also transform the old_to_new matrix the same way, and we will use the |
173 | // second matrix to transform new_to_old. |
174 | |
175 | int64_t m_g = (*S)[index][index] / g; |
176 | int64_t n_g = (*S)[index][j] / g; |
177 | |
178 | for (size_t i = index; i < m; ++i) { |
179 | int64_t new_i_index = a * (*S)[i][index] + b * (*S)[i][j]; |
180 | int64_t new_i_j = n_g * (*S)[i][index] - m_g * (*S)[i][j]; |
181 | (*S)[i][index] = new_i_index; |
182 | (*S)[i][j] = new_i_j; |
183 | } |
184 | // We do exactly the same transformations with V |
185 | for (size_t i = 0; i < n; ++i) { |
186 | int64_t new_i_index = a * (*V)[i][index] + b * (*V)[i][j]; |
187 | int64_t new_i_j = n_g * (*V)[i][index] - m_g * (*V)[i][j]; |
188 | (*V)[i][index] = new_i_index; |
189 | (*V)[i][j] = new_i_j; |
190 | } |
191 | // And apply reverse transformations to new_to_old. |
192 | PrimExpr ea = tir::make_const((*x)[j].dtype(), a); |
193 | PrimExpr eb = tir::make_const((*x)[index].dtype(), b); |
194 | PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g); |
195 | PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g); |
196 | PrimExpr new_index = e_m_g * (*x)[index] + e_n_g * (*x)[j]; |
197 | PrimExpr new_j = eb * (*x)[index] - ea * (*x)[j]; |
198 | (*x)[index] = new_index; |
199 | (*x)[j] = new_j; |
200 | } |
201 | } |
202 | |
203 | if (changed) { |
204 | // We might have changed the first column, so we have to zero it once more |
205 | // (or at least check if it's zero), so just perform this iteration once more. |
206 | index -= 1; |
207 | } |
208 | } |
209 | } |
210 | |
211 | Map<Var, Range> InferRange(const Map<Var, PrimExpr>& vars_to_infer, const Array<Var>& ori_vars, |
212 | const Map<Var, Range>& ori_ranges) { |
213 | // The resulting ranges |
214 | Map<Var, Range> new_ranges; |
215 | |
216 | std::unordered_set<const VarNode*> ori_vset; |
217 | for (const Var& v : ori_vars) { |
218 | ori_vset.insert(v.get()); |
219 | } |
220 | |
221 | std::unordered_map<const VarNode*, IntSet> var_intsets; |
222 | for (const auto& p : ori_ranges) { |
223 | if (!ori_vset.count(p.first.get())) { |
224 | // First of all, fill the new ranges with outer variable ranges |
225 | new_ranges.Set(p.first, p.second); |
226 | } |
227 | // Convert original ranges to IntSets |
228 | var_intsets[p.first.get()] = IntSet::FromRange(p.second); |
229 | } |
230 | |
231 | // Infer ranges for the new variables and add them to the resulting ranges |
232 | for (const auto& p : vars_to_infer) { |
233 | const auto& var = p.first; |
234 | const auto& expr = p.second; |
235 | Range range = EvalSet(expr, var_intsets).CoverRange(Range()); |
236 | if (range.defined()) { |
237 | new_ranges.Set(var, range); |
238 | } |
239 | } |
240 | return new_ranges; |
241 | } |
242 | |
243 | // pretty print matrix equation |
244 | void DebugPrint(const std::vector<std::vector<int64_t>>& S, |
245 | const std::vector<std::vector<int64_t>>& V, const std::vector<PrimExpr>& V_inv_x, |
246 | const std::vector<PrimExpr>& rhs) { |
247 | std::cout << "S:\n" ; |
248 | for (size_t i = 0; i < S.size(); ++i) { |
249 | for (auto e : S[i]) { |
250 | std::cout << e << "\t" ; |
251 | } |
252 | std::cout << "\t->\t" << rhs[i]; |
253 | std::cout << "\n" ; |
254 | } |
255 | std::cout << "V:\n" ; |
256 | for (const auto& r : V) { |
257 | for (auto e : r) { |
258 | std::cout << e << "\t" ; |
259 | } |
260 | std::cout << "\n" ; |
261 | } |
262 | std::cout << "V_inv x:\n" << Array<PrimExpr>(V_inv_x); |
263 | std::cout << "\n" << std::endl; |
264 | } |
265 | |
266 | IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve) { |
267 | // m: # of equations |
268 | // n: # of variables |
269 | // we first construct A_{mxn} x_{nx1} = y_{mx1} |
270 | // then get Smith normal form of matrix A, |
271 | // S_{mxn} = U_{mxm} A_{mxn} V_{nxn} |
272 | // => U^{-1} S V^{-1} x = y |
273 | // S V^{-1} x = U y |
274 | std::vector<PrimExpr> Uy; // mx1 |
275 | std::vector<std::vector<int64_t>> S; // mxn |
276 | std::vector<std::vector<int64_t>> V; // nxn |
277 | std::vector<PrimExpr> V_inv_x; // V^{-1} x, nx1 |
278 | // Conditions we don't know what to do with |
279 | std::vector<PrimExpr> rest; |
280 | |
281 | Analyzer analyzer_problem; |
282 | analyzer_problem.Bind(system_to_solve->ranges); |
283 | |
284 | size_t num_vars = system_to_solve->variables.size(); |
285 | |
286 | // initialize V_{nxn} with identity matrix, |
287 | // initialize V^{-1} x as x |
288 | for (size_t i = 0; i < num_vars; ++i) { |
289 | V.emplace_back(num_vars); |
290 | V.back()[i] = 1; |
291 | V_inv_x.push_back(system_to_solve->variables[i]); |
292 | } |
293 | |
294 | // Transform formulas into rows of the matrix |
295 | // S_{mxn} V^{-1}_{nxn} x_{nx1} = U y, in which n is # of variables |
296 | // here we initialize S_{mxn} to be A, U to be identity matrix. |
297 | for (const PrimExpr& equation : system_to_solve->relations) { |
298 | if (const tir::EQNode* eq = equation.as<tir::EQNode>()) { |
299 | // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] |
300 | Array<PrimExpr> coeffs = arith::DetectLinearEquation(analyzer_problem.Simplify(eq->a - eq->b), |
301 | system_to_solve->variables); |
302 | if (!coeffs.empty()) { |
303 | std::vector<int64_t> row; |
304 | for (size_t j = 0; j < coeffs.size() - 1; ++j) { |
305 | PrimExpr c = coeffs[j]; |
306 | if (const IntImmNode* ic = c.as<IntImmNode>()) { |
307 | row.push_back(ic->value); |
308 | } else { |
309 | // elements in matrix S V must be integers |
310 | // ignore equations that we cannot deal with. |
311 | LOG(WARNING) << "Cannot deal with non-integer coefficients, ignore equation " |
312 | << equation; |
313 | row.clear(); |
314 | break; |
315 | } |
316 | } |
317 | |
318 | if (!row.empty()) { |
319 | // S V^{-1} (a-b) = Uy |
320 | // V is identity for now |
321 | S.push_back(row); |
322 | Uy.push_back(-coeffs[coeffs.size() - 1]); |
323 | continue; |
324 | } |
325 | } |
326 | } |
327 | |
328 | // otherwise |
329 | rest.push_back(equation); |
330 | } |
331 | |
332 | // After diagonalizing, we have |
333 | // S_{mxn} is the Smith normal form (diagonal matrix) |
334 | // V_{nxn} is invertible |
335 | // V_inv_x is V^{-1} \times x |
336 | // Uy is U \times y |
337 | SmithNormalFormDiag(&S, &V, &V_inv_x, &Uy); |
338 | |
339 | Array<Var> new_vars; |
340 | Array<PrimExpr> new_relations; |
341 | Map<Var, PrimExpr> new_to_old_map; |
342 | Map<Var, PrimExpr> old_to_new_map; |
343 | |
344 | // Simplify right hand sides |
345 | for (PrimExpr r : Uy) { |
346 | r = analyzer_problem.Simplify(r); |
347 | } |
348 | |
349 | // Create the relations of the existence of a solution |
350 | for (size_t j = 0; j < S.size(); ++j) { |
351 | PrimExpr new_relation; |
352 | if (j >= num_vars || S[j][j] == 0) { |
353 | // The row of matrix is zero. A solution exists only if the Ub[j] is also zero |
354 | new_relation = (Uy[j] == 0); |
355 | } else { |
356 | // The diagonal element is non-zero. A solution exists only if the diagonal element |
357 | // is a divisor of the Ub[j] |
358 | new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0); |
359 | } |
360 | new_relation = analyzer_problem.Simplify(new_relation); |
361 | if (tir::is_const_int(new_relation, 0)) { |
362 | // unable to solve the system. |
363 | return IntConstraintsTransform(system_to_solve, |
364 | IntConstraints( |
365 | /*variables=*/{}, |
366 | /*ranges=*/{}, |
367 | /*relations=*/{tir::make_zero(DataType::Bool())}), |
368 | {}, {}); |
369 | } else if (!tir::is_const_int(new_relation, 1)) { |
370 | new_relations.push_back(new_relation); |
371 | } |
372 | } |
373 | |
374 | Array<PrimExpr> solution_for_V_inv_x; |
375 | // Now create new variables or directly solve the equations |
376 | // suppose the rank of A is r, aka r = # of non-zeros in S |
377 | // the solution of S_{mxn} V^{-1}_{nxn} x_{nx1} = U b |
378 | // is |
379 | // x = (pseudo-inverse of A) b + K_{(n)x(n-r)} z_{n-r} |
380 | // = V_{nxn} S^{-1}_{nxm} (Ub)_{mxn} + K_{(n)x(n-r)} z_{n-r} |
381 | // in which K is the right n-r columns of V, z is variable vector |
382 | // thus, |
383 | // V^{-1} x = S^{-1}_{nxm} (Ub)_{mxn} + |
384 | // [[0, ... 0]_{n-r}, ... [0, ..., 0], diag(1, ..., 1)_{(n-r)x(n-r)}] z_{n-r} |
385 | for (size_t j = 0; j < num_vars; ++j) { |
386 | if (j >= S.size() || S[j][j] == 0) { |
387 | // The j-th variable can take any integer value, create a tvm variable for it |
388 | PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]); |
389 | std::string name_hint = "n" + std::to_string(new_vars.size()); |
390 | if (const VarNode* v_old = to_old.as<VarNode>()) { |
391 | name_hint += "_" + v_old->name_hint; |
392 | } |
393 | Var v = Var(name_hint, V_inv_x[j].dtype()); |
394 | solution_for_V_inv_x.push_back(v); |
395 | new_vars.push_back(v); |
396 | new_to_old_map.Set(v, to_old); |
397 | } else { |
398 | // The j-th variable is just a single value, don't create a tvm variable |
399 | // S^{-1}_{nxm} Uy_{mxn} |
400 | if (S[j][j] >= 0) { |
401 | PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]); |
402 | solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(Uy[j], a))); |
403 | } else { |
404 | // This is required because some simplifiers |
405 | // have problems with dividing by negative numbers |
406 | PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]); |
407 | solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(-Uy[j], a))); |
408 | } |
409 | } |
410 | } |
411 | |
412 | // V V^{-1} x = x |
413 | for (size_t i = 0; i < num_vars; ++i) { |
414 | PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype()); |
415 | for (size_t j = 0; j < num_vars; ++j) { |
416 | e = e + tir::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; |
417 | } |
418 | e = analyzer_problem.Simplify(e); |
419 | old_to_new_map.Set(system_to_solve->variables[i], e); |
420 | } |
421 | |
422 | // The resulting ranges |
423 | Map<Var, Range> new_ranges = |
424 | InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges); |
425 | Analyzer analyzer_solution; |
426 | analyzer_solution.Bind(new_ranges); |
427 | |
428 | // We have to transform ranges of the old variables into relations over new variables because |
429 | // new ranges are not enough usually. |
430 | for (const auto& old_var : system_to_solve->variables) { |
431 | if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) { |
432 | const Range& old_range = system_to_solve->ranges.at(old_var); |
433 | PrimExpr express_by_new_vars = old_to_new_map.at(old_var); |
434 | PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); |
435 | PrimExpr upper_cond = |
436 | analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); |
437 | if (!tir::is_const_int(lower_cond, 1)) { |
438 | new_relations.push_back(lower_cond); |
439 | } |
440 | if (!tir::is_const_int(upper_cond, 1)) { |
441 | new_relations.push_back(upper_cond); |
442 | } |
443 | } |
444 | } |
445 | |
446 | // Add the rest conditions |
447 | for (const PrimExpr& cond : rest) { |
448 | new_relations.push_back(Substitute(cond, old_to_new_map)); |
449 | } |
450 | |
451 | IntConstraints solution(new_vars, new_ranges, new_relations); |
452 | IntConstraintsTransform transform(system_to_solve, solution, old_to_new_map, new_to_old_map); |
453 | |
454 | return transform; |
455 | } |
456 | |
457 | TVM_REGISTER_GLOBAL("arith.SolveLinearEquations" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
458 | if (args.size() == 1) { |
459 | *ret = SolveLinearEquations(args[0]); |
460 | } else if (args.size() == 3) { |
461 | IntConstraints problem(args[0], args[1], args[2]); |
462 | *ret = SolveLinearEquations(problem); |
463 | } else { |
464 | LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); |
465 | } |
466 | }); |
467 | |
468 | } // namespace arith |
469 | } // namespace tvm |
470 | |