1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
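/*!
 * \file tvm/arith/transitive_comparison_analyzer.cc
 * \brief Prove comparisons between integer expressions by chaining
 *        known (in)equalities whose sides may differ by constant offsets.
 */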
22
#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>

#include <algorithm>
#include <functional>
#include <optional>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "constraint_extract.h"
#include "pattern_match.h"
32
33namespace tvm {
34namespace arith {
35
36using namespace tir;
37
38class TransitiveComparisonAnalyzer::Impl {
39 public:
40 /* \brief Using previously specified knowns, compare the expressions provided
41 *
42 * \param lhs The left-hand side of the comparison
43 *
44 * \param rhs The right-hand side of the comparison
45 *
46 * \param propagate_inequalities If true, attempt to find a sequence
47 * of transitive inequalities that allow the lhs and rhs to be
48 * compared. If false, only use the known comparison that have been
49 * directly provided. Using `propagate_inequalities = false` is
50 * roughly equivalent to comparing against all known values with
51 * `ExprDeepEqual`, but also allowing for constant offsets on either
52 * side of the inequality.
53 *
54 * \return The most specific result that can be proven about the
55 * comparison. If nothing can be proven, returns kUnknown.
56 */
57 CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
58 bool propagate_inequalities = true) const;
59
60 /*! \brief Bind a variable as being equal to a known expression
61 *
62 * \param var The variable of interest.
63 * \param expr The bound expression
64 * \param allow_override Whether to allow override of existing information.
65 */
66 void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false);
67
68 /*! \brief Bind a variable as being within a specified range
69 *
70 * \param var The variable of interest.
71 * \param range The known range
72 * \param allow_override Whether to allow override of existing information.
73 */
74 void Bind(const tir::Var& var, const Range& expr, bool allow_override = false);
75
76 /*!
77 * \brief Update the internal state to enter constraint.
78 * \param constraint A constraint expression.
79 *
80 * \return An exit function that must be called to cleanup. May be
81 * `nullptr`, if no cleanup is required.
82 */
83 std::function<void()> EnterConstraint(const PrimExpr& expr);
84
85 private:
86 /* \brief Internal representation of a PrimExpr
87 *
88 * The Key enum serves two purposes.
89 *
90 * 1. Providing efficiency, as compared to a PrimExpr. Two keys are
91 * equal if and only if the corresponding PrimExprs would satisfy
92 * ExprDeepEqual. This allows two expressions to be checked for
93 * equivalency, without requiring a call to ExprDeepEqual for
94 * each comparison.
95 *
96 * 2. Providing type-safety, as compared to using `size_t` directly.
97 * Requiring an explicit conversion from an integer to a Key
98 * prevents accidental comparisons, especially if both loop
99 * iterators and Keys are used in the same scope.
100 *
101 * A Key should only be obtained using the methods `ExprToKey` and
102 * `ExprToPreviousKey`.
103 */
104 enum class Key : size_t {};
105
106 /*! \brief Convert an expression to internal representation
107 *
108 * If the expression has previously been converted to the internal
109 * representation, returns the same Key as has been used previously.
110 * Otherwise, generate and return a new Key.
111 *
112 * \param expr The PrimExpr to be converted
113 *
114 * \returns The Key representing the expression
115 *
116 * \see ExprToPreviousKey
117 */
118 Key ExprToKey(const PrimExpr& expr);
119
120 /*! \brief Convert an expression to internal representation
121 *
122 * If the expression has previously been converted to the internal
123 * representation, returns the same Key as has been used previously.
124 * Otherwise, return `std::nullopt`.
125 *
126 * \param expr The PrimExpr to be converted
127 *
128 * \returns The Key representing the expression, if one exists.
129 *
130 * \see ExprToKey
131 */
132 std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const;
133
134 /*! \brief The mapping from expression to Key
135 *
136 * Should not be used directly. Instead, use the helper functions
137 * `ExprToKey` and `ExprToPreviousKey`.
138 *
139 * \see ExprToKey
140 * \see ExprToPreviousKey
141 */
142 std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key;
143
144 /*! \brief Internal representation of a comparison operator */
145 struct Comparison {
146 /*! \brief Construct a comparison that represents `lhs OP rhs +
147 * offset`, where the operation is specified by the CompareResult.
148 */
149 Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result);
150
151 /*! \brief Utility function to validate that all GT and LT results
152 * have been normalized out
153 */
154 bool IsNormalized() const;
155
156 /*! \brief Move the specified expression to the LHS.
157 *
158 * \param new_lhs The argument that should be moved to the LHS of the
159 * comparison.
160 *
161 * \return If possible, returns a comparison that is equivalent to
162 * the current comparison, but with the specified LHS. If not
163 * possible, returns nullopt.
164 */
165 std::optional<Comparison> WithLHS(Key new_lhs) const;
166
167 /*! \brief Create the negation of the current comparison */
168 Comparison Negated() const;
169
170 /*! \brief Check the this comparison implies
171 *
172 * Returns true if this comparison being true implies that the
173 * other comparison must also be true. Returns false if the other
174 * comparison cannot be shown to be true.
175 */
176 bool Implies(const Comparison& other) const;
177
178 // The LHS of the comparison
179 Key lhs_;
180
181 // The RHS of the comparison, not including any constant offset.
182 Key rhs_;
183
184 // Additive offset on rhs
185 int64_t offset_{0};
186
187 // The comparison operator.
188 CompareResult result_{CompareResult::kInconsistent};
189 };
190
191 /*! \brief Generate a Comparison representing the given expression */
192 std::optional<Comparison> FromExpr(const PrimExpr& expr);
193
194 /*! \brief Utility function used by Bind and EnterConstraint
195 *
196 * \param expr The comparison expression, to be converted into
197 * internal Comparison objects.
198 *
199 * \param vec The vector to which the Comparison objects should be
200 * appended.
201 */
202 void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);
203
204 /*! Collect known comparisons between LHS and RHS, without propagation
205 *
206 * Allows the internal representation to handle any constant
207 * offsets, without searching for a sequence of inequalities.
208 *
209 * \param lhs_key The left-hand side of the comparison
210 *
211 * \param rhs_key The right-hand side of the comparison
212 *
213 * \returns A subset of `knowns_` and `scoped_knowns_`, filtered to
214 * only include comparisons between `lhs_key` and `rhs_key`,
215 * normalized such that `lhs_key` is on the left-hand side.
216 */
217 std::vector<Comparison> CollectDirectComparisons(Key lhs_key, Key rhs_key) const;
218
219 /*! Collect known comparisons between LHS and RHS, with propagation
220 *
221 * \param lhs_key The left-hand side of the comparison
222 *
223 * \param rhs_key The right-hand side of the comparison
224 *
225 * \returns All comparisons between `lhs_key` and `rhs_key`,
226 * including the explicitly-provided comparisons in `knowns_` and
227 * `scoped_knowns_`, and comparisons provable through a series of
228 * comparisons through other values. All comparisons returned are
229 * between `lhs_key` and `rhs_key`, and are normalized such that
230 * `lhs_key` is on the left-hand side.
231 */
232 std::vector<Comparison> CollectIndirectComparisons(Key lhs_key, Key rhs_key) const;
233
234 /*! \brief Internal function used by CollectIndirectComparisons
235 *
236 * Perform a depth-first search through the space of known
237 * expressions, starting at the LHS of a comparison. In this
238 * search, each expression is a node of a graph, and each known
239 * comparison is an edge of the graph.
240 *
241 * For example, suppose we have previous knowns of (A<=B), (B<=C+1)
242 * and (C<=D-5). The expressions [A,B,C,D] are the nodes of the
243 * search space. Each comparison is an edge connecting two
244 * expressions, such as (B<=C+1) connecting the expressions B and D.
245 * If we are attempting to compare expressions A and D, a search
246 * starting at expression A could follow each edge until reaching
247 * expression D, then combine the comparisons that compose the path
248 * into the expression A<=D-4.
249 *
250 * \param lhs_key The left-hand side of the comparison
251 *
252 * \param rhs_key The right-hand side of the comparison
253 *
254 * \returns A vector of comparisons between the two expressions.
255 */
256 std::vector<Comparison> DFSFromLHS(Key lhs_key, Key rhs_key) const;
257
258 /*! \brief Combine a set of comparisons that share a LHS and RHS
259 *
260 * \param lhs_to_rhs The comparisons to merge. These should all
261 * have the same LHS and RHS. This parameter will typically be the
262 * result from `CollectDirectComparisons` or
263 * `CollectIndirectComparisons`.
264 *
265 * \param offset The constant offset in the comparison being proven.
266 * This is extracted from any additive/subtractive constants in the
267 * `PrimExpr` arguments to `TryCompare`.
268 *
269 * \returns The possible comparisons between LHS and RHS provided
270 * inequalities.
271 */
272 CompareResult MergeComparisons(const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const;
273
274 /*! \brief Previous Range bindings
275 *
276 * Tracked separatedly to handle the `allow_override` option used by
277 * all sub-analyzers when binding variables.
278 */
279 Map<Var, Range> prev_bindings_;
280
281 /*! \brief Known comparisons based on definitionally-true statements
282 *
283 * For example, a Let binding, or the range of an iterator. These
284 * known statements are always true, based on the definition site of
285 * the variable. e.g. A loop iterator may never exceed the bounds
286 * of its loop.
287 */
288 std::vector<Comparison> knowns_;
289
290 /*! \brief Known comparisons based on scoped conditions
291 *
292 * For example, the condition of an IfThenElse. These known
293 * statements may only be used within the scope of the statement
294 * from which they were derived. e.g. After exiting an IfThenElse,
295 * the condition may no longer be true.
296 */
297 std::vector<Comparison> scoped_knowns_;
298};
299
300namespace {
301
302// Internal utility, return the CompareResult resulting from swapping
303// the left-hand side with the right-hand side.
304CompareResult Reverse(CompareResult res) {
305 switch (res) {
306 case CompareResult::kInconsistent:
307 return CompareResult::kInconsistent;
308 case CompareResult::kEQ:
309 return CompareResult::kEQ;
310 case CompareResult::kLT:
311 return CompareResult::kGT;
312 case CompareResult::kLE:
313 return CompareResult::kGE;
314 case CompareResult::kGT:
315 return CompareResult::kLT;
316 case CompareResult::kGE:
317 return CompareResult::kLE;
318 case CompareResult::kNE:
319 return CompareResult::kNE;
320 case CompareResult::kUnknown:
321 return CompareResult::kUnknown;
322 default:
323 LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
324 }
325}
326
327// Internal utility, return the CompareResult resulting from negating
328// the comparison.
329CompareResult Negate(CompareResult res) {
330 switch (res) {
331 case CompareResult::kInconsistent:
332 return CompareResult::kInconsistent;
333 case CompareResult::kUnknown:
334 return CompareResult::kUnknown;
335 default:
336 return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown));
337 }
338}
339
340// Internal utility, extract constant offsets out of the two sides of
341// a comparison. Given lhs and rhs, return a tuple of three elements
342// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
343// (lhs_inner OP rhs_inner + offset) are equivalent.
344std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
345 auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
346 PVar<PrimExpr> x;
347 PVar<IntImm> c;
348 if ((x + c).Match(expr)) {
349 return {x.Eval(), c.Eval()->value};
350 } else if ((x - c).Match(expr)) {
351 return {x.Eval(), -c.Eval()->value};
352 } else if (c.Match(expr)) {
353 return {0, c.Eval()->value};
354 } else {
355 return {expr, 0};
356 }
357 };
358
359 auto lhs_split = extract_offset(lhs);
360 auto rhs_split = extract_offset(rhs);
361 return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second};
362}
363
364} // namespace
365
366std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
367TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
368 CompareResult res;
369 PVar<PrimExpr> x, y;
370 if ((x <= y).Match(expr)) {
371 res = CompareResult::kLE;
372 } else if ((x >= y).Match(expr)) {
373 res = CompareResult::kGE;
374 } else if ((x < y).Match(expr)) {
375 res = CompareResult::kLT;
376 } else if ((x > y).Match(expr)) {
377 res = CompareResult::kGT;
378 } else if ((x == y).Match(expr)) {
379 res = CompareResult::kEQ;
380 } else if ((x != y).Match(expr)) {
381 res = CompareResult::kNE;
382 } else {
383 return std::nullopt;
384 }
385
386 PrimExpr lhs_expr = x.Eval();
387 PrimExpr rhs_expr = y.Eval();
388
389 if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
390 return std::nullopt;
391 }
392
393 auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
394 Key lhs_key = ExprToKey(lhs);
395 Key rhs_key = ExprToKey(rhs);
396
397 return Comparison(lhs_key, rhs_key, offset, res);
398}
399
400TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
401 CompareResult result)
402 : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
403 // Normalize the comparison to remove LT and GT expressions,
404 // reducing the number of operators that must be handled later. By
405 // eliminating LT and GT, instead of eliminating LE or GE, a
406 // potential off-by-one error is avoided.
407 //
408 // For floating-point numbers, (x < y + c1) and (y < z + c2) implies
409 // that (x < z + (c1 + c2)). For integer types, which the
410 // TransitiveComparisonAnalyzer is intended for use with integers,
411 // LT or GT can give a tighter constraint, though with a less
412 // convenient symmetry.
413 //
414 // i < j + c1, j < k + c2
415 // i <= j + c1 - 1, j <= k + c2 - 1
416 // i + 1 - c1 <= j, j <= k + c2 - 1
417 // i + 1 - c1 <= k + c2 - 1
418 // i <= k + c1 + c2 - 2
419 // i < k + (c1 + c2 - 1)
420 //
421 // By always working with LE and GE comparisons, we avoid needing to
422 // handle the offset of one that would be introduced by LT and GT at
423 // all points of use. The only point of use for LT and GT is when
424 // normalizing comparisons (i.e. this constructor).
425
426 if (result_ == CompareResult::kLT) {
427 result_ = CompareResult::kLE;
428 offset_ -= 1;
429 }
430 if (result_ == CompareResult::kGT) {
431 result_ = CompareResult::kGE;
432 offset_ += 1;
433 }
434}
435
436std::optional<TransitiveComparisonAnalyzer::Impl::Key>
437TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const {
438 auto it = expr_to_key.find(expr);
439 if (it != expr_to_key.end()) {
440 return it->second;
441 } else {
442 return std::nullopt;
443 }
444}
445
446TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey(
447 const PrimExpr& expr) {
448 if (auto prev = ExprToPreviousKey(expr)) {
449 return prev.value();
450 } else {
451 Key new_key = Key(expr_to_key.size());
452 expr_to_key[expr] = new_key;
453 return new_key;
454 }
455}
456
457bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const {
458 // These < and > should be removed during normalization. See the
459 // `Comparison::Comparison` constructor for further details.
460 return result_ != CompareResult::kLT && result_ != CompareResult::kGT;
461}
462
463std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
464TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const {
465 if (new_lhs == lhs_) {
466 return *this;
467 } else if (new_lhs == rhs_) {
468 return Comparison(rhs_, lhs_, -offset_, Reverse(result_));
469 } else {
470 return std::nullopt;
471 }
472}
473
474TransitiveComparisonAnalyzer::Impl::Comparison
475TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const {
476 return Comparison(lhs_, rhs_, offset_, Negate(result_));
477}
478
479bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
480 const TransitiveComparisonAnalyzer::Impl::Comparison& other) const {
481 ICHECK(lhs_ == other.lhs_);
482 ICHECK(rhs_ == other.rhs_);
483 ICHECK(IsNormalized());
484 ICHECK(other.IsNormalized());
485
486 if (result_ == other.result_ && offset_ == other.offset_) {
487 // if c1 == c2, x != y + c1 => x != y + c2
488 // if c1 == c2, x == y + c1 => x == y + c2
489 return true;
490 }
491
492 if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) {
493 if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) {
494 // if c1 <= c2, x <= y + c1 => x <= y + c2
495 // if c1 <= c2, x == y + c1 => x <= y + c2
496 return true;
497 }
498 }
499
500 if (other.result_ == CompareResult::kGE && offset_ >= other.offset_) {
501 if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) {
502 // if c1 >= c2, x == y + c1 => x >= y + c2
503 // if c1 >= c2, x >= y + c1 => x >= y + c2
504 return true;
505 }
506 }
507
508 if (other.result_ == CompareResult::kNE) {
509 if (result_ == CompareResult::kEQ && offset_ != other.offset_) {
510 // if c1 != c2, x == y + c1 => x != y + c2
511 return true;
512 }
513
514 if (result_ == CompareResult::kLE && offset_ < other.offset_) {
515 // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2
516 return true;
517 }
518
519 if (result_ == CompareResult::kGE && offset_ > other.offset_) {
520 // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2
521 return true;
522 }
523 }
524
525 return false;
526}
527
528TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
529TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
530
531CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
532 bool propagate_inequalities) {
533 return impl_->TryCompare(lhs, rhs, propagate_inequalities);
534}
535
536void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
537 impl_->Bind(var, expr, allow_override);
538}
539void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
540 impl_->Bind(var, range, allow_override);
541}
542
543std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) {
544 return impl_->EnterConstraint(constraint);
545}
546
547void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr,
548 std::vector<Comparison>* vec) {
549 for (const auto& subexpr : ExtractConstraints(expr, false)) {
550 if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) {
551 if (auto cmp = FromExpr(subexpr)) {
552 vec->push_back(cmp.value());
553 }
554 }
555 }
556}
557
558void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range,
559 bool allow_override) {
560 auto it = prev_bindings_.find(var);
561 if (it != prev_bindings_.end()) {
562 ExprDeepEqual expr_equal;
563 bool differs_from_previous = !expr_equal(range->min, (*it).second->min) ||
564 !expr_equal(range->extent, (*it).second->extent);
565 if (differs_from_previous) {
566 ICHECK(allow_override) << "Binding of variable " << var << " as " << range
567 << " conflicts with previous binding as " << (*it).second;
568 if (auto key = ExprToPreviousKey(var)) {
569 knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(),
570 [&](const auto& known) { return known.lhs_ == key.value(); }),
571 knowns_.end());
572 }
573 }
574 }
575
576 prev_bindings_.Set(var, range);
577
578 if (is_const_int(range->extent, 1)) {
579 AddKnown(var == range->min, &knowns_);
580 } else {
581 AddKnown(var >= range->min, &knowns_);
582 AddKnown(var < range->min + range->extent, &knowns_);
583 }
584}
585
586void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr,
587 bool allow_override) {
588 Bind(var, Range::FromMinExtent(expr, 1), allow_override);
589}
590
591std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) {
592 size_t old_literal_size = scoped_knowns_.size();
593 AddKnown(expr, &scoped_knowns_);
594 size_t new_literal_size = scoped_knowns_.size();
595
596 auto frecover = [old_literal_size, new_literal_size, this]() {
597 ICHECK_EQ(scoped_knowns_.size(), new_literal_size);
598 scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end());
599 };
600 return frecover;
601}
602
603CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
604 const PrimExpr& rhs_expr,
605 bool propagate_inequalities) const {
606 // Currently only supports integer checks
607 if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
608 return CompareResult::kUnknown;
609 }
610
611 // Bail out early if possible. This int check should have been
612 // constant-folded earlier, so this check shouldn't occur.
613 auto* x_int = lhs_expr.as<IntImmNode>();
614 auto* y_int = rhs_expr.as<IntImmNode>();
615 if (x_int && y_int) {
616 if (x_int->value < y_int->value) {
617 return CompareResult::kLT;
618 } else if (x_int->value > y_int->value) {
619 return CompareResult::kGT;
620 } else {
621 return CompareResult::kEQ;
622 }
623 }
624
625 auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
626 auto lhs_key = ExprToPreviousKey(lhs);
627 auto rhs_key = ExprToPreviousKey(rhs);
628
629 if (!lhs_key.has_value() || !rhs_key.has_value()) {
630 return CompareResult::kUnknown;
631 }
632
633 auto lhs_to_rhs = [&]() {
634 if (propagate_inequalities) {
635 return CollectIndirectComparisons(lhs_key.value(), rhs_key.value());
636 } else {
637 return CollectDirectComparisons(lhs_key.value(), rhs_key.value());
638 }
639 }();
640 return MergeComparisons(lhs_to_rhs, offset);
641}
642
643std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
644TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key rhs_key) const {
645 std::vector<Comparison> output;
646
647 auto append_known = [&](Comparison cmp) {
648 if (auto normalized = cmp.WithLHS(lhs_key)) {
649 if (normalized.value().rhs_ == rhs_key) {
650 output.push_back(normalized.value());
651 }
652 }
653 };
654
655 for (const auto& known : knowns_) {
656 append_known(known);
657 }
658 for (const auto& known : scoped_knowns_) {
659 append_known(known);
660 }
661
662 return output;
663}
664
665std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
666TransitiveComparisonAnalyzer::Impl::CollectIndirectComparisons(Key lhs_key, Key rhs_key) const {
667 auto output = DFSFromLHS(lhs_key, rhs_key);
668 for (Comparison cmp : DFSFromLHS(rhs_key, lhs_key)) {
669 auto opt_normalized = cmp.WithLHS(lhs_key);
670 ICHECK(opt_normalized.has_value());
671 output.push_back(opt_normalized.value());
672 }
673 return output;
674}
675
676std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
677TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const {
678 // Everything in `to_visit` has lhs as its lhs.
679 std::unordered_set<Key> seen;
680 std::unordered_set<Key> to_visit;
681 std::unordered_map<Key, std::vector<Comparison>> compared_to_lhs;
682
683 // Utility function to add a new known statement
684 auto declare_known = [&](Comparison cmp) {
685 std::vector<Comparison>& knowns = compared_to_lhs[cmp.rhs_];
686
687 // The comparison adds no new information, no modification
688 // required.
689 for (auto& prev_known : knowns) {
690 if (prev_known.Implies(cmp)) {
691 return;
692 }
693 }
694
695 // New information may require visiting a new expression.
696 if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) {
697 to_visit.insert(cmp.rhs_);
698 seen.insert(cmp.rhs_);
699 }
700
701 // This comparison is a stronger version of a previous constraint.
702 // Therefore, replace the old version entirely.
703 for (auto& prev_known : knowns) {
704 if (cmp.Implies(prev_known)) {
705 prev_known = cmp;
706 return;
707 }
708 }
709
710 // Neither a superset nor a subset of previously known
711 // constraints, must be tracked separately.
712 knowns.push_back(cmp);
713 };
714
715 // Initialize the search based on any known (in)equalities that use
716 // the LHS of the comparison.
717 for (const auto& known : knowns_) {
718 if (auto normalized = known.WithLHS(lhs_key)) {
719 declare_known(normalized.value());
720 }
721 }
722 for (const auto& known : scoped_knowns_) {
723 if (auto normalized = known.WithLHS(lhs_key)) {
724 declare_known(normalized.value());
725 }
726 }
727
728 // Walk through the space of all comparisons that can be made with
729 // LHS.
730 while (to_visit.size()) {
731 Key middle_key = *to_visit.begin();
732 to_visit.erase(to_visit.begin());
733
734 std::vector<Comparison>& prev_knowns_using_middle = compared_to_lhs.at(middle_key);
735 ICHECK(compared_to_lhs.count(middle_key));
736
737 std::vector<Comparison> new_knowns_using_lhs;
738
739 auto attempt_transitive = [&](Comparison cmp) {
740 ICHECK(cmp.IsNormalized());
741
742 Key right_key = cmp.rhs_;
743
744 if (right_key == lhs_key) {
745 return;
746 }
747
748 for (const auto& prev : prev_knowns_using_middle) {
749 CompareResult new_result = CompareResult::kUnknown;
750 int64_t new_offset = prev.offset_ + cmp.offset_;
751
752 if (prev.result_ == CompareResult::kEQ) {
753 // x == y + c1 && y OP z + c2, x OP z + (c1 + c2)
754 new_result = cmp.result_;
755 } else if (cmp.result_ == CompareResult::kEQ) {
756 // x OP y + c1 && y == z + c2, x OP z + (c1 + c2)
757 new_result = prev.result_;
758 } else if (prev.result_ == cmp.result_ &&
759 (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) {
760 // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2)
761 // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2)
762 //
763 // This condition is much simpler to write than the
764 // equivalent handling of < or of >, which is why the
765 // inequalities are normalized to <= and to >=. See
766 // `TransitiveComparisonAnalyzer::Impl::Comparison::Comparison`
767 // for further details.
768 new_result = prev.result_;
769 }
770
771 if (new_result != CompareResult::kUnknown) {
772 Comparison new_known(lhs_key, right_key, new_offset, new_result);
773 new_knowns_using_lhs.push_back(new_known);
774 }
775 }
776 };
777
778 // Attempt to prove a new comparison using one of the original
779 // known comparisons. We want to find a known such that
780 // `(LHS OP1 middle) && (middle OP2 right)` can be simplified
781 // into `(LHS OP3 right)`.
782 //
783 // Note: The right side is this step is not necessarily the RHS of
784 // the comparison we're trying to prove, as we may need to find
785 // intermediate comparisons first. For example, if we know that
786 // `a<=b`, `b<=c`, and `c<=d`, and we wish to prove that `a<=d`,
787 // we must first combine `a<=b` and `b<=c` into `a<=c`. During
788 // this first step, `b` is the "middle" and `c` is the "right".
789 // The next step can then combind `a<=c` and `c<=d` into `a<=d`.
790 for (const auto& known : knowns_) {
791 if (auto cmp = known.WithLHS(middle_key)) {
792 attempt_transitive(cmp.value());
793 }
794 }
795
796 for (const auto& known : scoped_knowns_) {
797 if (auto cmp = known.WithLHS(middle_key)) {
798 attempt_transitive(cmp.value());
799 }
800 }
801
802 // Collect together all new knowns, marking new nodes for visiting
803 // as needed.
804 for (const auto& new_known : new_knowns_using_lhs) {
805 declare_known(new_known);
806 }
807 }
808
809 if (auto it = compared_to_lhs.find(rhs_key); it != compared_to_lhs.end()) {
810 return it->second;
811 } else {
812 // There are known comparisons involving the LHS and the RHS, but
813 // no path that connects the two expressions.
814 return {};
815 }
816}
817
818CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons(
819 const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const {
820 // Just because we found a comparison involving LHS and RHS doesn't
821 // mean that it's useful. e.g. Knowing that `x < y` doesn't let us
822 // prove whether `x + 5 < y`.
823 CompareResult result = CompareResult::kUnknown;
824 for (const auto& cmp : lhs_to_rhs) {
825 switch (cmp.result_) {
826 case CompareResult::kInconsistent:
827 result = CompareResult::kInconsistent;
828 break;
829
830 case CompareResult::kEQ:
831 if (offset == cmp.offset_) {
832 result = result & CompareResult::kEQ;
833 } else {
834 result = result & CompareResult::kNE;
835 }
836 break;
837
838 case CompareResult::kLE:
839 if (cmp.offset_ < offset) {
840 result = result & CompareResult::kLT;
841 } else if (cmp.offset_ <= offset) {
842 result = result & CompareResult::kLE;
843 }
844 break;
845
846 case CompareResult::kGE:
847 if (cmp.offset_ > offset) {
848 result = result & CompareResult::kGT;
849 } else if (cmp.offset_ >= offset) {
850 result = result & CompareResult::kGE;
851 }
852 break;
853
854 case CompareResult::kNE:
855 if (offset == cmp.offset_) {
856 result = result & CompareResult::kNE;
857 }
858 break;
859
860 case CompareResult::kUnknown:
861 break;
862
863 case CompareResult::kGT:
864 case CompareResult::kLT:
865 LOG(FATAL) << "Internal error, normalized comparisons should only include <= and >=";
866
867 default:
868 LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(cmp.result_);
869 }
870 }
871
872 return result;
873}
874
875} // namespace arith
876} // namespace tvm
877