1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/arith/transitive_comparison_analyzer.cc |
21 | */ |
22 | |
23 | #include <tvm/arith/analyzer.h> |
24 | #include <tvm/tir/analysis.h> |
25 | #include <tvm/tir/expr.h> |
26 | |
27 | #include <optional> |
28 | #include <vector> |
29 | |
30 | #include "constraint_extract.h" |
31 | #include "pattern_match.h" |
32 | |
33 | namespace tvm { |
34 | namespace arith { |
35 | |
36 | using namespace tir; |
37 | |
38 | class TransitiveComparisonAnalyzer::Impl { |
39 | public: |
40 | /* \brief Using previously specified knowns, compare the expressions provided |
41 | * |
42 | * \param lhs The left-hand side of the comparison |
43 | * |
44 | * \param rhs The right-hand side of the comparison |
45 | * |
46 | * \param propagate_inequalities If true, attempt to find a sequence |
47 | * of transitive inequalities that allow the lhs and rhs to be |
48 | * compared. If false, only use the known comparison that have been |
49 | * directly provided. Using `propagate_inequalities = false` is |
50 | * roughly equivalent to comparing against all known values with |
51 | * `ExprDeepEqual`, but also allowing for constant offsets on either |
52 | * side of the inequality. |
53 | * |
54 | * \return The most specific result that can be proven about the |
55 | * comparison. If nothing can be proven, returns kUnknown. |
56 | */ |
57 | CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs, |
58 | bool propagate_inequalities = true) const; |
59 | |
60 | /*! \brief Bind a variable as being equal to a known expression |
61 | * |
62 | * \param var The variable of interest. |
63 | * \param expr The bound expression |
64 | * \param allow_override Whether to allow override of existing information. |
65 | */ |
66 | void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); |
67 | |
68 | /*! \brief Bind a variable as being within a specified range |
69 | * |
70 | * \param var The variable of interest. |
71 | * \param range The known range |
72 | * \param allow_override Whether to allow override of existing information. |
73 | */ |
74 | void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); |
75 | |
76 | /*! |
77 | * \brief Update the internal state to enter constraint. |
78 | * \param constraint A constraint expression. |
79 | * |
80 | * \return An exit function that must be called to cleanup. May be |
81 | * `nullptr`, if no cleanup is required. |
82 | */ |
83 | std::function<void()> EnterConstraint(const PrimExpr& expr); |
84 | |
85 | private: |
86 | /* \brief Internal representation of a PrimExpr |
87 | * |
88 | * The Key enum serves two purposes. |
89 | * |
90 | * 1. Providing efficiency, as compared to a PrimExpr. Two keys are |
91 | * equal if and only if the corresponding PrimExprs would satisfy |
92 | * ExprDeepEqual. This allows two expressions to be checked for |
93 | * equivalency, without requiring a call to ExprDeepEqual for |
94 | * each comparison. |
95 | * |
96 | * 2. Providing type-safety, as compared to using `size_t` directly. |
97 | * Requiring an explicit conversion from an integer to a Key |
98 | * prevents accidental comparisons, especially if both loop |
99 | * iterators and Keys are used in the same scope. |
100 | * |
101 | * A Key should only be obtained using the methods `ExprToKey` and |
102 | * `ExprToPreviousKey`. |
103 | */ |
104 | enum class Key : size_t {}; |
105 | |
106 | /*! \brief Convert an expression to internal representation |
107 | * |
108 | * If the expression has previously been converted to the internal |
109 | * representation, returns the same Key as has been used previously. |
110 | * Otherwise, generate and return a new Key. |
111 | * |
112 | * \param expr The PrimExpr to be converted |
113 | * |
114 | * \returns The Key representing the expression |
115 | * |
116 | * \see ExprToPreviousKey |
117 | */ |
118 | Key ExprToKey(const PrimExpr& expr); |
119 | |
120 | /*! \brief Convert an expression to internal representation |
121 | * |
122 | * If the expression has previously been converted to the internal |
123 | * representation, returns the same Key as has been used previously. |
124 | * Otherwise, return `std::nullopt`. |
125 | * |
126 | * \param expr The PrimExpr to be converted |
127 | * |
128 | * \returns The Key representing the expression, if one exists. |
129 | * |
130 | * \see ExprToKey |
131 | */ |
132 | std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; |
133 | |
134 | /*! \brief The mapping from expression to Key |
135 | * |
136 | * Should not be used directly. Instead, use the helper functions |
137 | * `ExprToKey` and `ExprToPreviousKey`. |
138 | * |
139 | * \see ExprToKey |
140 | * \see ExprToPreviousKey |
141 | */ |
142 | std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; |
143 | |
144 | /*! \brief Internal representation of a comparison operator */ |
145 | struct Comparison { |
146 | /*! \brief Construct a comparison that represents `lhs OP rhs + |
147 | * offset`, where the operation is specified by the CompareResult. |
148 | */ |
149 | Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); |
150 | |
151 | /*! \brief Utility function to validate that all GT and LT results |
152 | * have been normalized out |
153 | */ |
154 | bool IsNormalized() const; |
155 | |
156 | /*! \brief Move the specified expression to the LHS. |
157 | * |
158 | * \param new_lhs The argument that should be moved to the LHS of the |
159 | * comparison. |
160 | * |
161 | * \return If possible, returns a comparison that is equivalent to |
162 | * the current comparison, but with the specified LHS. If not |
163 | * possible, returns nullopt. |
164 | */ |
165 | std::optional<Comparison> WithLHS(Key new_lhs) const; |
166 | |
167 | /*! \brief Create the negation of the current comparison */ |
168 | Comparison Negated() const; |
169 | |
170 | /*! \brief Check the this comparison implies |
171 | * |
172 | * Returns true if this comparison being true implies that the |
173 | * other comparison must also be true. Returns false if the other |
174 | * comparison cannot be shown to be true. |
175 | */ |
176 | bool Implies(const Comparison& other) const; |
177 | |
178 | // The LHS of the comparison |
179 | Key lhs_; |
180 | |
181 | // The RHS of the comparison, not including any constant offset. |
182 | Key rhs_; |
183 | |
184 | // Additive offset on rhs |
185 | int64_t offset_{0}; |
186 | |
187 | // The comparison operator. |
188 | CompareResult result_{CompareResult::kInconsistent}; |
189 | }; |
190 | |
191 | /*! \brief Generate a Comparison representing the given expression */ |
192 | std::optional<Comparison> FromExpr(const PrimExpr& expr); |
193 | |
194 | /*! \brief Utility function used by Bind and EnterConstraint |
195 | * |
196 | * \param expr The comparison expression, to be converted into |
197 | * internal Comparison objects. |
198 | * |
199 | * \param vec The vector to which the Comparison objects should be |
200 | * appended. |
201 | */ |
202 | void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); |
203 | |
204 | /*! Collect known comparisons between LHS and RHS, without propagation |
205 | * |
206 | * Allows the internal representation to handle any constant |
207 | * offsets, without searching for a sequence of inequalities. |
208 | * |
209 | * \param lhs_key The left-hand side of the comparison |
210 | * |
211 | * \param rhs_key The right-hand side of the comparison |
212 | * |
213 | * \returns A subset of `knowns_` and `scoped_knowns_`, filtered to |
214 | * only include comparisons between `lhs_key` and `rhs_key`, |
215 | * normalized such that `lhs_key` is on the left-hand side. |
216 | */ |
217 | std::vector<Comparison> CollectDirectComparisons(Key lhs_key, Key rhs_key) const; |
218 | |
219 | /*! Collect known comparisons between LHS and RHS, with propagation |
220 | * |
221 | * \param lhs_key The left-hand side of the comparison |
222 | * |
223 | * \param rhs_key The right-hand side of the comparison |
224 | * |
225 | * \returns All comparisons between `lhs_key` and `rhs_key`, |
226 | * including the explicitly-provided comparisons in `knowns_` and |
227 | * `scoped_knowns_`, and comparisons provable through a series of |
228 | * comparisons through other values. All comparisons returned are |
229 | * between `lhs_key` and `rhs_key`, and are normalized such that |
230 | * `lhs_key` is on the left-hand side. |
231 | */ |
232 | std::vector<Comparison> CollectIndirectComparisons(Key lhs_key, Key rhs_key) const; |
233 | |
234 | /*! \brief Internal function used by CollectIndirectComparisons |
235 | * |
236 | * Perform a depth-first search through the space of known |
237 | * expressions, starting at the LHS of a comparison. In this |
238 | * search, each expression is a node of a graph, and each known |
239 | * comparison is an edge of the graph. |
240 | * |
241 | * For example, suppose we have previous knowns of (A<=B), (B<=C+1) |
242 | * and (C<=D-5). The expressions [A,B,C,D] are the nodes of the |
243 | * search space. Each comparison is an edge connecting two |
244 | * expressions, such as (B<=C+1) connecting the expressions B and D. |
245 | * If we are attempting to compare expressions A and D, a search |
246 | * starting at expression A could follow each edge until reaching |
247 | * expression D, then combine the comparisons that compose the path |
248 | * into the expression A<=D-4. |
249 | * |
250 | * \param lhs_key The left-hand side of the comparison |
251 | * |
252 | * \param rhs_key The right-hand side of the comparison |
253 | * |
254 | * \returns A vector of comparisons between the two expressions. |
255 | */ |
256 | std::vector<Comparison> DFSFromLHS(Key lhs_key, Key rhs_key) const; |
257 | |
258 | /*! \brief Combine a set of comparisons that share a LHS and RHS |
259 | * |
260 | * \param lhs_to_rhs The comparisons to merge. These should all |
261 | * have the same LHS and RHS. This parameter will typically be the |
262 | * result from `CollectDirectComparisons` or |
263 | * `CollectIndirectComparisons`. |
264 | * |
265 | * \param offset The constant offset in the comparison being proven. |
266 | * This is extracted from any additive/subtractive constants in the |
267 | * `PrimExpr` arguments to `TryCompare`. |
268 | * |
269 | * \returns The possible comparisons between LHS and RHS provided |
270 | * inequalities. |
271 | */ |
272 | CompareResult MergeComparisons(const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const; |
273 | |
274 | /*! \brief Previous Range bindings |
275 | * |
276 | * Tracked separatedly to handle the `allow_override` option used by |
277 | * all sub-analyzers when binding variables. |
278 | */ |
279 | Map<Var, Range> prev_bindings_; |
280 | |
281 | /*! \brief Known comparisons based on definitionally-true statements |
282 | * |
283 | * For example, a Let binding, or the range of an iterator. These |
284 | * known statements are always true, based on the definition site of |
285 | * the variable. e.g. A loop iterator may never exceed the bounds |
286 | * of its loop. |
287 | */ |
288 | std::vector<Comparison> knowns_; |
289 | |
290 | /*! \brief Known comparisons based on scoped conditions |
291 | * |
292 | * For example, the condition of an IfThenElse. These known |
293 | * statements may only be used within the scope of the statement |
294 | * from which they were derived. e.g. After exiting an IfThenElse, |
295 | * the condition may no longer be true. |
296 | */ |
297 | std::vector<Comparison> scoped_knowns_; |
298 | }; |
299 | |
300 | namespace { |
301 | |
302 | // Internal utility, return the CompareResult resulting from swapping |
303 | // the left-hand side with the right-hand side. |
304 | CompareResult Reverse(CompareResult res) { |
305 | switch (res) { |
306 | case CompareResult::kInconsistent: |
307 | return CompareResult::kInconsistent; |
308 | case CompareResult::kEQ: |
309 | return CompareResult::kEQ; |
310 | case CompareResult::kLT: |
311 | return CompareResult::kGT; |
312 | case CompareResult::kLE: |
313 | return CompareResult::kGE; |
314 | case CompareResult::kGT: |
315 | return CompareResult::kLT; |
316 | case CompareResult::kGE: |
317 | return CompareResult::kLE; |
318 | case CompareResult::kNE: |
319 | return CompareResult::kNE; |
320 | case CompareResult::kUnknown: |
321 | return CompareResult::kUnknown; |
322 | default: |
323 | LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); |
324 | } |
325 | } |
326 | |
327 | // Internal utility, return the CompareResult resulting from negating |
328 | // the comparison. |
329 | CompareResult Negate(CompareResult res) { |
330 | switch (res) { |
331 | case CompareResult::kInconsistent: |
332 | return CompareResult::kInconsistent; |
333 | case CompareResult::kUnknown: |
334 | return CompareResult::kUnknown; |
335 | default: |
336 | return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); |
337 | } |
338 | } |
339 | |
340 | // Internal utility, extract constant offsets out of the two sides of |
341 | // a comparison. Given lhs and rhs, return a tuple of three elements |
342 | // (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and |
343 | // (lhs_inner OP rhs_inner + offset) are equivalent. |
344 | std::tuple<PrimExpr, PrimExpr, int64_t> (const PrimExpr& lhs, const PrimExpr& rhs) { |
345 | auto = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> { |
346 | PVar<PrimExpr> x; |
347 | PVar<IntImm> c; |
348 | if ((x + c).Match(expr)) { |
349 | return {x.Eval(), c.Eval()->value}; |
350 | } else if ((x - c).Match(expr)) { |
351 | return {x.Eval(), -c.Eval()->value}; |
352 | } else if (c.Match(expr)) { |
353 | return {0, c.Eval()->value}; |
354 | } else { |
355 | return {expr, 0}; |
356 | } |
357 | }; |
358 | |
359 | auto lhs_split = extract_offset(lhs); |
360 | auto rhs_split = extract_offset(rhs); |
361 | return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second}; |
362 | } |
363 | |
364 | } // namespace |
365 | |
366 | std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> |
367 | TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) { |
368 | CompareResult res; |
369 | PVar<PrimExpr> x, y; |
370 | if ((x <= y).Match(expr)) { |
371 | res = CompareResult::kLE; |
372 | } else if ((x >= y).Match(expr)) { |
373 | res = CompareResult::kGE; |
374 | } else if ((x < y).Match(expr)) { |
375 | res = CompareResult::kLT; |
376 | } else if ((x > y).Match(expr)) { |
377 | res = CompareResult::kGT; |
378 | } else if ((x == y).Match(expr)) { |
379 | res = CompareResult::kEQ; |
380 | } else if ((x != y).Match(expr)) { |
381 | res = CompareResult::kNE; |
382 | } else { |
383 | return std::nullopt; |
384 | } |
385 | |
386 | PrimExpr lhs_expr = x.Eval(); |
387 | PrimExpr rhs_expr = y.Eval(); |
388 | |
389 | if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) { |
390 | return std::nullopt; |
391 | } |
392 | |
393 | auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); |
394 | Key lhs_key = ExprToKey(lhs); |
395 | Key rhs_key = ExprToKey(rhs); |
396 | |
397 | return Comparison(lhs_key, rhs_key, offset, res); |
398 | } |
399 | |
400 | TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset, |
401 | CompareResult result) |
402 | : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) { |
403 | // Normalize the comparison to remove LT and GT expressions, |
404 | // reducing the number of operators that must be handled later. By |
405 | // eliminating LT and GT, instead of eliminating LE or GE, a |
406 | // potential off-by-one error is avoided. |
407 | // |
408 | // For floating-point numbers, (x < y + c1) and (y < z + c2) implies |
409 | // that (x < z + (c1 + c2)). For integer types, which the |
410 | // TransitiveComparisonAnalyzer is intended for use with integers, |
411 | // LT or GT can give a tighter constraint, though with a less |
412 | // convenient symmetry. |
413 | // |
414 | // i < j + c1, j < k + c2 |
415 | // i <= j + c1 - 1, j <= k + c2 - 1 |
416 | // i + 1 - c1 <= j, j <= k + c2 - 1 |
417 | // i + 1 - c1 <= k + c2 - 1 |
418 | // i <= k + c1 + c2 - 2 |
419 | // i < k + (c1 + c2 - 1) |
420 | // |
421 | // By always working with LE and GE comparisons, we avoid needing to |
422 | // handle the offset of one that would be introduced by LT and GT at |
423 | // all points of use. The only point of use for LT and GT is when |
424 | // normalizing comparisons (i.e. this constructor). |
425 | |
426 | if (result_ == CompareResult::kLT) { |
427 | result_ = CompareResult::kLE; |
428 | offset_ -= 1; |
429 | } |
430 | if (result_ == CompareResult::kGT) { |
431 | result_ = CompareResult::kGE; |
432 | offset_ += 1; |
433 | } |
434 | } |
435 | |
436 | std::optional<TransitiveComparisonAnalyzer::Impl::Key> |
437 | TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const { |
438 | auto it = expr_to_key.find(expr); |
439 | if (it != expr_to_key.end()) { |
440 | return it->second; |
441 | } else { |
442 | return std::nullopt; |
443 | } |
444 | } |
445 | |
446 | TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey( |
447 | const PrimExpr& expr) { |
448 | if (auto prev = ExprToPreviousKey(expr)) { |
449 | return prev.value(); |
450 | } else { |
451 | Key new_key = Key(expr_to_key.size()); |
452 | expr_to_key[expr] = new_key; |
453 | return new_key; |
454 | } |
455 | } |
456 | |
457 | bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const { |
458 | // These < and > should be removed during normalization. See the |
459 | // `Comparison::Comparison` constructor for further details. |
460 | return result_ != CompareResult::kLT && result_ != CompareResult::kGT; |
461 | } |
462 | |
463 | std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> |
464 | TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const { |
465 | if (new_lhs == lhs_) { |
466 | return *this; |
467 | } else if (new_lhs == rhs_) { |
468 | return Comparison(rhs_, lhs_, -offset_, Reverse(result_)); |
469 | } else { |
470 | return std::nullopt; |
471 | } |
472 | } |
473 | |
474 | TransitiveComparisonAnalyzer::Impl::Comparison |
475 | TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { |
476 | return Comparison(lhs_, rhs_, offset_, Negate(result_)); |
477 | } |
478 | |
479 | bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( |
480 | const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { |
481 | ICHECK(lhs_ == other.lhs_); |
482 | ICHECK(rhs_ == other.rhs_); |
483 | ICHECK(IsNormalized()); |
484 | ICHECK(other.IsNormalized()); |
485 | |
486 | if (result_ == other.result_ && offset_ == other.offset_) { |
487 | // if c1 == c2, x != y + c1 => x != y + c2 |
488 | // if c1 == c2, x == y + c1 => x == y + c2 |
489 | return true; |
490 | } |
491 | |
492 | if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) { |
493 | if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) { |
494 | // if c1 <= c2, x <= y + c1 => x <= y + c2 |
495 | // if c1 <= c2, x == y + c1 => x <= y + c2 |
496 | return true; |
497 | } |
498 | } |
499 | |
500 | if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) { |
501 | if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) { |
502 | // if c1 >= c2, x == y + c1 => x >= y + c2 |
503 | // if c1 >= c2, x >= y + c1 => x >= y + c2 |
504 | return true; |
505 | } |
506 | } |
507 | |
508 | if (other.result_ == CompareResult::kNE) { |
509 | if (result_ == CompareResult::kEQ && offset_ != other.offset_) { |
510 | // if c1 != c2, x == y + c1 => x != y + c2 |
511 | return true; |
512 | } |
513 | |
514 | if (result_ == CompareResult::kLE && offset_ < other.offset_) { |
515 | // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2 |
516 | return true; |
517 | } |
518 | |
519 | if (result_ == CompareResult::kGE && offset_ > other.offset_) { |
520 | // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2 |
521 | return true; |
522 | } |
523 | } |
524 | |
525 | return false; |
526 | } |
527 | |
528 | TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {} |
529 | TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} |
530 | |
531 | CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs, |
532 | bool propagate_inequalities) { |
533 | return impl_->TryCompare(lhs, rhs, propagate_inequalities); |
534 | } |
535 | |
536 | void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { |
537 | impl_->Bind(var, expr, allow_override); |
538 | } |
539 | void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { |
540 | impl_->Bind(var, range, allow_override); |
541 | } |
542 | |
543 | std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) { |
544 | return impl_->EnterConstraint(constraint); |
545 | } |
546 | |
547 | void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, |
548 | std::vector<Comparison>* vec) { |
549 | for (const auto& subexpr : ExtractConstraints(expr, false)) { |
550 | if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { |
551 | if (auto cmp = FromExpr(subexpr)) { |
552 | vec->push_back(cmp.value()); |
553 | } |
554 | } |
555 | } |
556 | } |
557 | |
558 | void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range, |
559 | bool allow_override) { |
560 | auto it = prev_bindings_.find(var); |
561 | if (it != prev_bindings_.end()) { |
562 | ExprDeepEqual expr_equal; |
563 | bool differs_from_previous = !expr_equal(range->min, (*it).second->min) || |
564 | !expr_equal(range->extent, (*it).second->extent); |
565 | if (differs_from_previous) { |
566 | ICHECK(allow_override) << "Binding of variable " << var << " as " << range |
567 | << " conflicts with previous binding as " << (*it).second; |
568 | if (auto key = ExprToPreviousKey(var)) { |
569 | knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), |
570 | [&](const auto& known) { return known.lhs_ == key.value(); }), |
571 | knowns_.end()); |
572 | } |
573 | } |
574 | } |
575 | |
576 | prev_bindings_.Set(var, range); |
577 | |
578 | if (is_const_int(range->extent, 1)) { |
579 | AddKnown(var == range->min, &knowns_); |
580 | } else { |
581 | AddKnown(var >= range->min, &knowns_); |
582 | AddKnown(var < range->min + range->extent, &knowns_); |
583 | } |
584 | } |
585 | |
586 | void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, |
587 | bool allow_override) { |
588 | Bind(var, Range::FromMinExtent(expr, 1), allow_override); |
589 | } |
590 | |
591 | std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { |
592 | size_t old_literal_size = scoped_knowns_.size(); |
593 | AddKnown(expr, &scoped_knowns_); |
594 | size_t new_literal_size = scoped_knowns_.size(); |
595 | |
596 | auto frecover = [old_literal_size, new_literal_size, this]() { |
597 | ICHECK_EQ(scoped_knowns_.size(), new_literal_size); |
598 | scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end()); |
599 | }; |
600 | return frecover; |
601 | } |
602 | |
603 | CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, |
604 | const PrimExpr& rhs_expr, |
605 | bool propagate_inequalities) const { |
606 | // Currently only supports integer checks |
607 | if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { |
608 | return CompareResult::kUnknown; |
609 | } |
610 | |
611 | // Bail out early if possible. This int check should have been |
612 | // constant-folded earlier, so this check shouldn't occur. |
613 | auto* x_int = lhs_expr.as<IntImmNode>(); |
614 | auto* y_int = rhs_expr.as<IntImmNode>(); |
615 | if (x_int && y_int) { |
616 | if (x_int->value < y_int->value) { |
617 | return CompareResult::kLT; |
618 | } else if (x_int->value > y_int->value) { |
619 | return CompareResult::kGT; |
620 | } else { |
621 | return CompareResult::kEQ; |
622 | } |
623 | } |
624 | |
625 | auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); |
626 | auto lhs_key = ExprToPreviousKey(lhs); |
627 | auto rhs_key = ExprToPreviousKey(rhs); |
628 | |
629 | if (!lhs_key.has_value() || !rhs_key.has_value()) { |
630 | return CompareResult::kUnknown; |
631 | } |
632 | |
633 | auto lhs_to_rhs = [&]() { |
634 | if (propagate_inequalities) { |
635 | return CollectIndirectComparisons(lhs_key.value(), rhs_key.value()); |
636 | } else { |
637 | return CollectDirectComparisons(lhs_key.value(), rhs_key.value()); |
638 | } |
639 | }(); |
640 | return MergeComparisons(lhs_to_rhs, offset); |
641 | } |
642 | |
643 | std::vector<TransitiveComparisonAnalyzer::Impl::Comparison> |
644 | TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key rhs_key) const { |
645 | std::vector<Comparison> output; |
646 | |
647 | auto append_known = [&](Comparison cmp) { |
648 | if (auto normalized = cmp.WithLHS(lhs_key)) { |
649 | if (normalized.value().rhs_ == rhs_key) { |
650 | output.push_back(normalized.value()); |
651 | } |
652 | } |
653 | }; |
654 | |
655 | for (const auto& known : knowns_) { |
656 | append_known(known); |
657 | } |
658 | for (const auto& known : scoped_knowns_) { |
659 | append_known(known); |
660 | } |
661 | |
662 | return output; |
663 | } |
664 | |
665 | std::vector<TransitiveComparisonAnalyzer::Impl::Comparison> |
666 | TransitiveComparisonAnalyzer::Impl::CollectIndirectComparisons(Key lhs_key, Key rhs_key) const { |
667 | auto output = DFSFromLHS(lhs_key, rhs_key); |
668 | for (Comparison cmp : DFSFromLHS(rhs_key, lhs_key)) { |
669 | auto opt_normalized = cmp.WithLHS(lhs_key); |
670 | ICHECK(opt_normalized.has_value()); |
671 | output.push_back(opt_normalized.value()); |
672 | } |
673 | return output; |
674 | } |
675 | |
676 | std::vector<TransitiveComparisonAnalyzer::Impl::Comparison> |
677 | TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const { |
678 | // Everything in `to_visit` has lhs as its lhs. |
679 | std::unordered_set<Key> seen; |
680 | std::unordered_set<Key> to_visit; |
681 | std::unordered_map<Key, std::vector<Comparison>> compared_to_lhs; |
682 | |
683 | // Utility function to add a new known statement |
684 | auto declare_known = [&](Comparison cmp) { |
685 | std::vector<Comparison>& knowns = compared_to_lhs[cmp.rhs_]; |
686 | |
687 | // The comparison adds no new information, no modification |
688 | // required. |
689 | for (auto& prev_known : knowns) { |
690 | if (prev_known.Implies(cmp)) { |
691 | return; |
692 | } |
693 | } |
694 | |
695 | // New information may require visiting a new expression. |
696 | if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) { |
697 | to_visit.insert(cmp.rhs_); |
698 | seen.insert(cmp.rhs_); |
699 | } |
700 | |
701 | // This comparison is a stronger version of a previous constraint. |
702 | // Therefore, replace the old version entirely. |
703 | for (auto& prev_known : knowns) { |
704 | if (cmp.Implies(prev_known)) { |
705 | prev_known = cmp; |
706 | return; |
707 | } |
708 | } |
709 | |
710 | // Neither a superset nor a subset of previously known |
711 | // constraints, must be tracked separately. |
712 | knowns.push_back(cmp); |
713 | }; |
714 | |
715 | // Initialize the search based on any known (in)equalities that use |
716 | // the LHS of the comparison. |
717 | for (const auto& known : knowns_) { |
718 | if (auto normalized = known.WithLHS(lhs_key)) { |
719 | declare_known(normalized.value()); |
720 | } |
721 | } |
722 | for (const auto& known : scoped_knowns_) { |
723 | if (auto normalized = known.WithLHS(lhs_key)) { |
724 | declare_known(normalized.value()); |
725 | } |
726 | } |
727 | |
728 | // Walk through the space of all comparisons that can be made with |
729 | // LHS. |
730 | while (to_visit.size()) { |
731 | Key middle_key = *to_visit.begin(); |
732 | to_visit.erase(to_visit.begin()); |
733 | |
734 | std::vector<Comparison>& prev_knowns_using_middle = compared_to_lhs.at(middle_key); |
735 | ICHECK(compared_to_lhs.count(middle_key)); |
736 | |
737 | std::vector<Comparison> new_knowns_using_lhs; |
738 | |
739 | auto attempt_transitive = [&](Comparison cmp) { |
740 | ICHECK(cmp.IsNormalized()); |
741 | |
742 | Key right_key = cmp.rhs_; |
743 | |
744 | if (right_key == lhs_key) { |
745 | return; |
746 | } |
747 | |
748 | for (const auto& prev : prev_knowns_using_middle) { |
749 | CompareResult new_result = CompareResult::kUnknown; |
750 | int64_t new_offset = prev.offset_ + cmp.offset_; |
751 | |
752 | if (prev.result_ == CompareResult::kEQ) { |
753 | // x == y + c1 && y OP z + c2, x OP z + (c1 + c2) |
754 | new_result = cmp.result_; |
755 | } else if (cmp.result_ == CompareResult::kEQ) { |
756 | // x OP y + c1 && y == z + c2, x OP z + (c1 + c2) |
757 | new_result = prev.result_; |
758 | } else if (prev.result_ == cmp.result_ && |
759 | (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) { |
760 | // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2) |
761 | // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2) |
762 | // |
763 | // This condition is much simpler to write than the |
764 | // equivalent handling of < or of >, which is why the |
765 | // inequalities are normalized to <= and to >=. See |
766 | // `TransitiveComparisonAnalyzer::Impl::Comparison::Comparison` |
767 | // for further details. |
768 | new_result = prev.result_; |
769 | } |
770 | |
771 | if (new_result != CompareResult::kUnknown) { |
772 | Comparison new_known(lhs_key, right_key, new_offset, new_result); |
773 | new_knowns_using_lhs.push_back(new_known); |
774 | } |
775 | } |
776 | }; |
777 | |
778 | // Attempt to prove a new comparison using one of the original |
779 | // known comparisons. We want to find a known such that |
780 | // `(LHS OP1 middle) && (middle OP2 right)` can be simplified |
781 | // into `(LHS OP3 right)`. |
782 | // |
783 | // Note: The right side is this step is not necessarily the RHS of |
784 | // the comparison we're trying to prove, as we may need to find |
785 | // intermediate comparisons first. For example, if we know that |
786 | // `a<=b`, `b<=c`, and `c<=d`, and we wish to prove that `a<=d`, |
787 | // we must first combine `a<=b` and `b<=c` into `a<=c`. During |
788 | // this first step, `b` is the "middle" and `c` is the "right". |
789 | // The next step can then combind `a<=c` and `c<=d` into `a<=d`. |
790 | for (const auto& known : knowns_) { |
791 | if (auto cmp = known.WithLHS(middle_key)) { |
792 | attempt_transitive(cmp.value()); |
793 | } |
794 | } |
795 | |
796 | for (const auto& known : scoped_knowns_) { |
797 | if (auto cmp = known.WithLHS(middle_key)) { |
798 | attempt_transitive(cmp.value()); |
799 | } |
800 | } |
801 | |
802 | // Collect together all new knowns, marking new nodes for visiting |
803 | // as needed. |
804 | for (const auto& new_known : new_knowns_using_lhs) { |
805 | declare_known(new_known); |
806 | } |
807 | } |
808 | |
809 | if (auto it = compared_to_lhs.find(rhs_key); it != compared_to_lhs.end()) { |
810 | return it->second; |
811 | } else { |
812 | // There are known comparisons involving the LHS and the RHS, but |
813 | // no path that connects the two expressions. |
814 | return {}; |
815 | } |
816 | } |
817 | |
818 | CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons( |
819 | const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const { |
820 | // Just because we found a comparison involving LHS and RHS doesn't |
821 | // mean that it's useful. e.g. Knowing that `x < y` doesn't let us |
822 | // prove whether `x + 5 < y`. |
823 | CompareResult result = CompareResult::kUnknown; |
824 | for (const auto& cmp : lhs_to_rhs) { |
825 | switch (cmp.result_) { |
826 | case CompareResult::kInconsistent: |
827 | result = CompareResult::kInconsistent; |
828 | break; |
829 | |
830 | case CompareResult::kEQ: |
831 | if (offset == cmp.offset_) { |
832 | result = result & CompareResult::kEQ; |
833 | } else { |
834 | result = result & CompareResult::kNE; |
835 | } |
836 | break; |
837 | |
838 | case CompareResult::kLE: |
839 | if (cmp.offset_ < offset) { |
840 | result = result & CompareResult::kLT; |
841 | } else if (cmp.offset_ <= offset) { |
842 | result = result & CompareResult::kLE; |
843 | } |
844 | break; |
845 | |
846 | case CompareResult::kGE: |
847 | if (cmp.offset_ > offset) { |
848 | result = result & CompareResult::kGT; |
849 | } else if (cmp.offset_ >= offset) { |
850 | result = result & CompareResult::kGE; |
851 | } |
852 | break; |
853 | |
854 | case CompareResult::kNE: |
855 | if (offset == cmp.offset_) { |
856 | result = result & CompareResult::kNE; |
857 | } |
858 | break; |
859 | |
860 | case CompareResult::kUnknown: |
861 | break; |
862 | |
863 | case CompareResult::kGT: |
864 | case CompareResult::kLT: |
865 | LOG(FATAL) << "Internal error, normalized comparisons should only include <= and >=" ; |
866 | |
867 | default: |
868 | LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(cmp.result_); |
869 | } |
870 | } |
871 | |
872 | return result; |
873 | } |
874 | |
875 | } // namespace arith |
876 | } // namespace tvm |
877 | |