1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/solve_linear_inequality.cc |
22 | * \brief Solve linear inequalities. |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/arith/int_solver.h> |
26 | #include <tvm/arith/pattern.h> |
27 | #include <tvm/runtime/data_type.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/tir/analysis.h> |
30 | #include <tvm/tir/expr.h> |
31 | #include <tvm/tir/op.h> |
32 | #include <tvm/tir/stmt_functor.h> |
33 | |
34 | #include "int_operator.h" |
35 | |
36 | namespace tvm { |
37 | namespace arith { |
38 | |
39 | using namespace tvm::runtime; |
40 | using namespace tvm::tir; |
41 | |
42 | struct ExprLess { |
43 | bool operator()(const PrimExpr& l, const PrimExpr& r) const { |
44 | return CalculateExprComplexity(l) < CalculateExprComplexity(r); |
45 | } |
46 | }; |
47 | |
48 | void DebugPrint(const std::vector<PrimExpr>& current_ineq_set, |
49 | const std::vector<PrimExpr>& next_ineq_set, const std::vector<PrimExpr>& rest, |
50 | const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos, |
51 | const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) { |
52 | std::cout << "Current ineq set:\n[" ; |
53 | for (auto& ineq : current_ineq_set) { |
54 | std::cout << ineq << ", " ; |
55 | } |
56 | std::cout << "]\n" ; |
57 | |
58 | std::cout << "Next ineq set:\n[" ; |
59 | for (auto& ineq : next_ineq_set) { |
60 | std::cout << ineq << ", " ; |
61 | } |
62 | std::cout << "]\n" ; |
63 | |
64 | std::cout << "coef_pos:\n[" ; |
65 | for (auto& coef : coef_pos) { |
66 | std::cout << "(" << coef.first << ", " << coef.second << "), " ; |
67 | } |
68 | std::cout << "]\n" ; |
69 | |
70 | std::cout << "coef_neg:\n[" ; |
71 | for (auto& coef : coef_neg) { |
72 | std::cout << "(" << coef.first << ", " << coef.second << "), " ; |
73 | } |
74 | std::cout << "]\n" ; |
75 | } |
76 | |
77 | /*! |
78 | * \brief normalize to the form `expr <= 0` |
79 | */ |
80 | class NormalizeComparisons : public ExprMutator { |
81 | public: |
82 | PrimExpr VisitExpr_(const EQNode* op) override { return Make<EQ>(op->a, op->b); } |
83 | PrimExpr VisitExpr_(const NENode* op) override { return Make<NE>(op->a, op->b); } |
84 | PrimExpr VisitExpr_(const LTNode* op) override { return Make<LT>(op->a, op->b); } |
85 | PrimExpr VisitExpr_(const LENode* op) override { return Make<LE>(op->a, op->b); } |
86 | PrimExpr VisitExpr_(const GTNode* op) override { return Make<LT>(op->b, op->a); } |
87 | PrimExpr VisitExpr_(const GENode* op) override { return Make<LE>(op->b, op->a); } |
88 | |
89 | private: |
90 | template <class T> |
91 | PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { |
92 | // rewrite LT to LE for ints |
93 | if (std::is_same<T, LT>::value && (a.dtype().is_int() || a.dtype().is_uint())) { |
94 | return LE(analyzer_.Simplify(a - b + 1), make_zero(a.dtype())); |
95 | } |
96 | return T(analyzer_.Simplify(a - b), make_zero(a.dtype())); |
97 | } |
98 | arith::Analyzer analyzer_; |
99 | }; |
100 | |
101 | void AddInequality(std::vector<PrimExpr>* inequality_set, const PrimExpr& new_ineq, |
102 | Analyzer* analyzer) { |
103 | if (analyzer->CanProve(new_ineq) || |
104 | std::find_if(inequality_set->begin(), inequality_set->end(), [&](const PrimExpr& e) { |
105 | return StructuralEqual()(e, new_ineq); |
106 | }) != inequality_set->end()) { |
107 | // redundant: follows from the vranges |
108 | // or has already been added |
109 | return; |
110 | } |
111 | if (const LENode* new_le = new_ineq.as<LENode>()) { |
112 | for (auto iter = inequality_set->begin(); iter != inequality_set->end();) { |
113 | const LENode* le = iter->as<LENode>(); |
114 | if (le && analyzer->CanProve(new_le->a - le->a <= 0)) { |
115 | return; |
116 | } else if (le && analyzer->CanProve(le->a - new_le->a <= 0)) { |
117 | iter = inequality_set->erase(iter); |
118 | } else { |
119 | ++iter; |
120 | } |
121 | } |
122 | } |
123 | |
124 | inequality_set->push_back(new_ineq); |
125 | } |
126 | |
127 | void ClassifyByPolarity(const Var& var, const std::vector<PrimExpr>& current_ineq_set, |
128 | std::vector<PrimExpr>* next_ineq_set, std::vector<PrimExpr>* rest, |
129 | std::vector<std::pair<int64_t, PrimExpr>>* coef_pos, |
130 | std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) { |
131 | // Take formulas from current_ineq_set and classify them according to polarity wrt var |
132 | // and store to coef_pos and coef_neg respectively. |
133 | for (const PrimExpr& ineq : current_ineq_set) { |
134 | if (const LENode* le = ineq.as<LENode>()) { |
135 | Array<PrimExpr> coef = arith::DetectLinearEquation(le->a, {var}); |
136 | if (!coef.empty() && is_const_int(coef[0])) { |
137 | int64_t coef0 = *as_const_int(coef[0]); |
138 | if (coef0 == 0) { |
139 | // zero polarity, straight to next_ineq_set |
140 | AddInequality(next_ineq_set, ineq, analyzer); |
141 | } else if (coef0 > 0) { |
142 | coef_pos->push_back({coef0, coef[1]}); |
143 | } else if (coef0 < 0) { |
144 | coef_neg->push_back({coef0, coef[1]}); |
145 | } |
146 | continue; |
147 | } |
148 | } else if (const EQNode* eq = ineq.as<EQNode>()) { |
149 | Array<PrimExpr> coef = arith::DetectLinearEquation(eq->a, {var}); |
150 | if (!coef.empty() && is_const_int(coef[0])) { |
151 | int64_t coef0 = *as_const_int(coef[0]); |
152 | if (coef0 == 0) { |
153 | // zero polarity, straight to next_ineq_set |
154 | AddInequality(next_ineq_set, ineq, analyzer); |
155 | } else if (coef0 > 0) { |
156 | // Equalities may be considered as pairs of two inequalities |
157 | coef_pos->push_back({coef0, coef[1]}); |
158 | coef_neg->push_back({-coef0, -coef[1]}); |
159 | } else if (coef0 < 0) { |
160 | coef_pos->push_back({-coef0, -coef[1]}); |
161 | coef_neg->push_back({coef0, coef[1]}); |
162 | } |
163 | continue; |
164 | } |
165 | } |
166 | |
167 | // if nothing worked, put it in rest |
168 | rest->push_back(ineq); |
169 | } |
170 | } |
171 | |
172 | void MoveEquality(std::vector<PrimExpr>* upper_bounds, std::vector<PrimExpr>* lower_bounds, |
173 | std::vector<PrimExpr>* equalities) { |
174 | // those exist in both upper & lower bounds will be moved to equalities |
175 | for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { |
176 | auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(), |
177 | [&](const PrimExpr& e) { return StructuralEqual()(e, *ub); }); |
178 | if (lb != lower_bounds->end()) { |
179 | equalities->push_back(*lb); |
180 | lower_bounds->erase(lb); |
181 | ub = upper_bounds->erase(ub); |
182 | } else { |
183 | ++ub; |
184 | } |
185 | } |
186 | } |
187 | |
188 | PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve) { |
189 | arith::Analyzer analyzer; |
190 | analyzer.Bind(system_to_solve->ranges); |
191 | |
192 | // The algorithm consists in doing the following things for each variable v |
193 | // - Take formulas from `current_ineq_set_to_solve` and |
194 | // classify them according to polarity wrt v. |
195 | // - Combine each formula of positive polarity (wrt v) |
196 | // with each formula of negative polarity. |
197 | // - Put the resulting combinations into `next_ineq_set_to_solve` |
198 | // along with unclassifiable formulas. |
199 | // - Replace `current_ineq_set_to_solve` with `next_ineq_set_to_solve` |
200 | // and move to the next variable. |
201 | |
202 | // normalized inequality |
203 | std::vector<PrimExpr> current_ineq_set_to_solve; |
204 | std::vector<PrimExpr> next_ineq_set_to_solve; |
205 | // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 |
206 | std::vector<std::pair<int64_t, PrimExpr>> coef_pos; |
207 | // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 |
208 | std::vector<std::pair<int64_t, PrimExpr>> coef_neg; |
209 | |
210 | // formulas we don't know what to do with |
211 | std::vector<PrimExpr> rest; |
212 | |
213 | // Simplify each inequality into the form `expr <= 0` and add to current formulas |
214 | for (const PrimExpr& ineq : system_to_solve->relations) { |
215 | AddInequality(¤t_ineq_set_to_solve, |
216 | NormalizeComparisons()(analyzer.Simplify(ineq, kSimplifyRewriteCanonicalRewrite)), |
217 | &analyzer); |
218 | } |
219 | |
220 | Map<Var, IntGroupBounds> res_bounds; |
221 | for (const Var& v : system_to_solve->variables) { |
222 | ICHECK(!res_bounds.count(v)) |
223 | << "Variable " << v |
224 | << " appears more than one time in the `variables` which might be a bug" ; |
225 | |
226 | next_ineq_set_to_solve.clear(); |
227 | coef_pos.clear(); |
228 | coef_neg.clear(); |
229 | |
230 | // Add bounds from vranges |
231 | if (system_to_solve->ranges.count(v)) { |
232 | const Range& range = system_to_solve->ranges[v]; |
233 | PrimExpr range_lbound = analyzer.Simplify(range->min, kSimplifyRewriteCanonicalRewrite); |
234 | PrimExpr range_ubound = |
235 | analyzer.Simplify(range->min + range->extent - 1, kSimplifyRewriteCanonicalRewrite); |
236 | coef_neg.push_back({-1, range_lbound}); |
237 | coef_pos.push_back({1, -range_ubound}); |
238 | } |
239 | |
240 | ClassifyByPolarity(v, current_ineq_set_to_solve, &next_ineq_set_to_solve, &rest, &coef_pos, |
241 | &coef_neg, &analyzer); |
242 | |
243 | // Combine each positive inequality with each negative one (by adding them together) |
244 | int64_t gcd_x, gcd_y; |
245 | for (const auto& pos : coef_pos) { |
246 | for (const auto& neg : coef_neg) { |
247 | auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, &gcd_y); |
248 | PrimExpr c_pos = make_const(v.dtype(), neg.first / first_gcd); |
249 | PrimExpr c_neg = make_const(v.dtype(), pos.first / first_gcd); |
250 | // eliminate the current variable |
251 | PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second; |
252 | PrimExpr new_ineq = LE(new_lhs, make_zero(pos.second.dtype())); |
253 | // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify |
254 | // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 |
255 | // with steps = 2 it's (y*2) - 10 <= 0 |
256 | new_ineq = |
257 | NormalizeComparisons()(analyzer.Simplify(new_ineq, kSimplifyRewriteCanonicalRewrite)); |
258 | AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer); |
259 | } |
260 | } |
261 | |
262 | // Now we have to generate resulting (in)equalities for the variable v |
263 | |
264 | // Find the common denominator in a sense |
265 | // We will generate formulas of the form coef_lcm*v <= bound |
266 | int64_t coef_lcm = 1; |
267 | for (const auto& pos : coef_pos) { |
268 | coef_lcm = LeastCommonMultiple(coef_lcm, pos.first); |
269 | } |
270 | for (const auto& neg : coef_neg) { |
271 | coef_lcm = LeastCommonMultiple(coef_lcm, -neg.first); |
272 | } |
273 | |
274 | // The resulting lower and upper bounds |
275 | std::vector<PrimExpr> upper_bounds; |
276 | std::vector<PrimExpr> lower_bounds; |
277 | upper_bounds.reserve(coef_pos.size()); |
278 | lower_bounds.reserve(coef_neg.size()); |
279 | |
280 | for (const auto& pos : coef_pos) { |
281 | PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second; |
282 | bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite); |
283 | // Don't add if any of the existing bounds is better |
284 | if (std::any_of(upper_bounds.begin(), upper_bounds.end(), |
285 | [&bound, &analyzer](const PrimExpr& o) { |
286 | return analyzer.CanProve(o - bound <= 0); |
287 | })) { |
288 | continue; |
289 | } |
290 | // Erase all worse bounds |
291 | for (auto iter = upper_bounds.begin(); iter != upper_bounds.end();) { |
292 | if (analyzer.CanProve(*iter - bound >= 0)) { |
293 | iter = upper_bounds.erase(iter); |
294 | } else { |
295 | ++iter; |
296 | } |
297 | } |
298 | // Add the upper bound |
299 | upper_bounds.push_back(bound); |
300 | } |
301 | for (const auto& neg : coef_neg) { |
302 | PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second; |
303 | bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite); |
304 | // Don't add if any of the existing bounds is better |
305 | if (std::any_of(lower_bounds.begin(), lower_bounds.end(), |
306 | [&bound, &analyzer](const PrimExpr& o) { |
307 | return analyzer.CanProve(o - bound >= 0); |
308 | })) { |
309 | continue; |
310 | } |
311 | // Erase all worse bounds |
312 | for (auto iter = lower_bounds.begin(); iter != lower_bounds.end();) { |
313 | if (analyzer.CanProve(*iter - bound <= 0)) { |
314 | iter = lower_bounds.erase(iter); |
315 | } else { |
316 | ++iter; |
317 | } |
318 | } |
319 | // Add the lower bound |
320 | lower_bounds.push_back(bound); |
321 | } |
322 | |
323 | std::vector<PrimExpr> equal; |
324 | equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); |
325 | MoveEquality(&upper_bounds, &lower_bounds, &equal); |
326 | std::vector<PrimExpr> equal_list(equal.begin(), equal.end()); |
327 | std::sort(equal_list.begin(), equal_list.end(), ExprLess()); |
328 | |
329 | // Write it to the result. |
330 | IntGroupBounds bnds(make_const(v.dtype(), coef_lcm), |
331 | Array<PrimExpr>(lower_bounds.begin(), lower_bounds.end()), |
332 | Array<PrimExpr>(equal_list.begin(), equal_list.end()), |
333 | Array<PrimExpr>(upper_bounds.begin(), upper_bounds.end())); |
334 | res_bounds.Set(v, bnds); |
335 | |
336 | std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve); |
337 | } |
338 | |
339 | // Everything that is left goes to res.relations |
340 | Array<PrimExpr> other_conditions; |
341 | for (const PrimExpr& e : current_ineq_set_to_solve) { |
342 | PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite); |
343 | if (is_const_int(e_simp, 0)) { |
344 | // contradiction detected |
345 | other_conditions = {const_false()}; |
346 | break; |
347 | } else if (is_const_int(e_simp, 1)) { |
348 | continue; |
349 | } else { |
350 | other_conditions.push_back(e_simp); |
351 | } |
352 | } |
353 | |
354 | for (const PrimExpr& e : rest) { |
355 | other_conditions.push_back(e); |
356 | } |
357 | |
358 | return {res_bounds, other_conditions}; |
359 | } |
360 | |
361 | #ifdef _MSC_VER |
362 | #pragma optimize("g", off) |
363 | #endif |
364 | IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { |
365 | // Resulting ranges will contain ranges for the new variables and for the variables that are |
366 | // not in the inequalities->variables but are in inequalities->ranges |
367 | // It will be useful when solving Jacobian axes jac_xxx) |
368 | Map<Var, Range> res_ranges; |
369 | // we get a set of equality, lower, upper bound of each variable. |
370 | auto solved_system = SolveLinearInequalities(inequalities); |
371 | |
372 | Map<Var, IntGroupBounds> solved_bounds = solved_system.first; |
373 | Array<PrimExpr> solved_other_relations = solved_system.second; |
374 | |
375 | Array<PrimExpr> res_relations; |
376 | |
377 | // this keeps being updated during determining the range of each variable. |
378 | Map<Var, Range> vranges; |
379 | for (std::pair<Var, Range> vr : inequalities->ranges) { |
380 | vranges.Set(vr.first, vr.second); |
381 | } |
382 | |
383 | // We process variables in the reverse direction to start with the most independent one. |
384 | // This order is needed to compute new ranges. |
385 | for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { |
386 | arith::Analyzer analyzer; |
387 | analyzer.Bind(vranges); |
388 | |
389 | const Var& var = *it; |
390 | ICHECK(solved_bounds.count(var)); |
391 | auto bnd = solved_bounds[var]; |
392 | if (is_one(bnd->coef) && !bnd->equal.empty()) { |
393 | // There is an equation of the form `v == expr`, so this variable can be completely removed. |
394 | // Note that we use the 0-th expression because they are ordered by complexity, |
395 | // so it must be the simplest one. |
396 | // The MSVC compiler optimization must be disabled for the expression `bnd->equal[0]` which |
397 | // triggers an internal compiler error. |
398 | Range best_range(bnd->equal[0], |
399 | analyzer.Simplify(bnd->equal[0] + 1, kSimplifyRewriteCanonicalRewrite)); |
400 | res_ranges.Set(var, best_range); |
401 | vranges.Set(var, best_range); |
402 | } else { |
403 | if (vranges.count(var) > 0) { |
404 | bnd = bnd + vranges[var]; |
405 | } |
406 | |
407 | auto best_range = bnd.FindBestRange(vranges); |
408 | |
409 | if (best_range.defined()) { |
410 | if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { |
411 | // range.extent <= 0 implies the input inequality system is unsolvable |
412 | return IntConstraints(/*variables=*/{}, /*ranges=*/{}, |
413 | /*relations=*/{tir::make_zero(DataType::Bool())}); |
414 | } |
415 | res_ranges.Set(var, best_range); |
416 | vranges.Set(var, best_range); |
417 | } |
418 | } |
419 | } |
420 | |
421 | // Add the original conditions to the resulting conditions |
422 | arith::Analyzer analyzer; |
423 | analyzer.Bind(vranges); |
424 | for (const PrimExpr& old_cond : |
425 | AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) { |
426 | if (!analyzer.CanProve(old_cond)) { |
427 | // those not represented in vranges (res_ranges) |
428 | res_relations.push_back(old_cond); |
429 | } |
430 | } |
431 | |
432 | IntConstraints system(inequalities->variables, res_ranges, res_relations); |
433 | |
434 | return system; |
435 | } |
436 | #ifdef _MSC_VER |
437 | #pragma optimize("g", on) |
438 | #endif |
439 | |
440 | IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities) { |
441 | // Resulting ranges will contain ranges for the new variables and for the variables that are |
442 | // not in the inequalities->variables but are in inequalities->ranges (jac_xxx) |
443 | Map<Var, Range> res_ranges; |
444 | // we get a set of equality, lower, upper bound of each variable. |
445 | auto solved_system = SolveLinearInequalities(inequalities); |
446 | Map<Var, IntGroupBounds> solved_bounds = solved_system.first; |
447 | Array<PrimExpr> solved_other_relations = solved_system.second; |
448 | |
449 | arith::Analyzer analyzer; |
450 | |
451 | Map<Var, PrimExpr> res_src_to_dst; |
452 | Map<Var, PrimExpr> res_dst_to_src; |
453 | Array<Var> res_variables; |
454 | Array<PrimExpr> res_relations; |
455 | |
456 | // this keeps being updated during determining the range of each variable. |
457 | Map<Var, Range> vranges; |
458 | for (std::pair<Var, Range> vr : inequalities->ranges) { |
459 | vranges.Set(vr.first, vr.second); |
460 | } |
461 | analyzer.Bind(vranges); |
462 | |
463 | // We process variables in the reverse direction to start with the most independent one. |
464 | // This order is needed to compute new ranges. |
465 | for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { |
466 | const Var& var = *it; |
467 | auto bnd = solved_bounds[var]; |
468 | // Note that we replace old vars with new ones |
469 | bnd = bnd.Substitute(res_src_to_dst); |
470 | |
471 | if (is_one(bnd->coef) && !bnd->equal.empty()) { |
472 | // There is an equation of the form `v == expr`, |
473 | // so this variable can be completely removed. |
474 | // Note that we use the 0-th expression because they are ordered by complexity, |
475 | // so it must be the simplest one. |
476 | res_src_to_dst.Set(var, bnd->equal[0]); |
477 | } else { |
478 | if (vranges.count(var) > 0) { |
479 | bnd = bnd + vranges[var]; |
480 | } |
481 | |
482 | auto best_range = bnd.FindBestRange(vranges); |
483 | |
484 | Var new_var = var.copy_with_suffix(".shifted" ); |
485 | if (!best_range.defined()) { |
486 | res_src_to_dst.Set(var, var); |
487 | res_dst_to_src.Set(var, var); |
488 | res_variables.push_back(var); |
489 | } else if (is_const_int(best_range->extent, 1)) { |
490 | // Don't create an itervar, just replace it everywhere with its min |
491 | res_src_to_dst.Set(var, best_range->min); |
492 | } else if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { |
493 | // range.extent <= 0 implies the input inequality system is unsolvable |
494 | return IntConstraintsTransform(inequalities, |
495 | IntConstraints( |
496 | /*variables=*/{}, |
497 | /*ranges=*/{}, |
498 | /*relations=*/{tir::make_zero(DataType::Bool())}), |
499 | {}, {}); |
500 | } else { |
501 | // created new_var starts from 0 |
502 | res_src_to_dst.Set(var, new_var + best_range->min); |
503 | // Note that we are substituting old with new, so best_range contains new var, |
504 | // that is we have to substitute new with old in best_range here |
505 | res_dst_to_src.Set(new_var, |
506 | analyzer.Simplify(var - Substitute(best_range->min, res_dst_to_src))); |
507 | |
508 | // Add the new var to the resulting axis |
509 | auto range = Range(make_zero(new_var.dtype()), best_range->extent); |
510 | res_variables.push_back(new_var); |
511 | res_ranges.Set(new_var, range); |
512 | |
513 | vranges.Set(new_var, range); |
514 | analyzer.Bind(new_var, range); |
515 | } |
516 | } |
517 | } |
518 | |
519 | // Add the original conditions (with variables substituted) to the resulting conditions |
520 | for (const PrimExpr& old_cond : |
521 | AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) { |
522 | PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst)); |
523 | if (!is_const_int(new_cond, 1)) { |
524 | // those not represented in vranges (res_ranges) |
525 | res_relations.push_back(new_cond); |
526 | } |
527 | } |
528 | |
529 | // Reverse the axis so that it matches the order of the original variables |
530 | res_variables = Array<Var>(res_variables.rbegin(), res_variables.rend()); |
531 | |
532 | IntConstraints new_inequalities(res_variables, res_ranges, res_relations); |
533 | IntConstraintsTransform transform(inequalities, new_inequalities, res_src_to_dst, res_dst_to_src); |
534 | |
535 | return transform; |
536 | } |
537 | |
538 | TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition" ) |
539 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
540 | IntConstraints problem; |
541 | PartialSolvedInequalities ret_ineq; |
542 | if (args.size() == 1) { |
543 | problem = args[0]; |
544 | ret_ineq = SolveLinearInequalities(problem); |
545 | } else if (args.size() == 3) { |
546 | problem = IntConstraints(args[0], args[1], args[2]); |
547 | ret_ineq = SolveLinearInequalities(problem); |
548 | } else { |
549 | LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets " |
550 | << args.size(); |
551 | } |
552 | *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second); |
553 | }); |
554 | |
555 | TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
556 | if (args.size() == 1) { |
557 | *ret = SolveInequalitiesToRange(args[0]); |
558 | } else if (args.size() == 3) { |
559 | IntConstraints problem(args[0], args[1], args[2]); |
560 | *ret = SolveInequalitiesToRange(problem); |
561 | } else { |
562 | LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " << args.size(); |
563 | } |
564 | }); |
565 | |
566 | TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange" ) |
567 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
568 | if (args.size() == 1) { |
569 | *ret = SolveInequalitiesDeskewRange(args[0]); |
570 | } else if (args.size() == 3) { |
571 | IntConstraints problem(args[0], args[1], args[2]); |
572 | *ret = SolveInequalitiesDeskewRange(problem); |
573 | } else { |
574 | LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets " |
575 | << args.size(); |
576 | } |
577 | }); |
578 | |
579 | } // namespace arith |
580 | } // namespace tvm |
581 | |