1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/analyzer.h
22 * \brief Algebra expression simplifications.
23 */
24#ifndef TVM_ARITH_ANALYZER_H_
25#define TVM_ARITH_ANALYZER_H_
26
#include <tvm/arith/int_set.h>
#include <tvm/ir/expr.h>
#include <tvm/support/with.h>

#include <functional>
#include <limits>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
35
36namespace tvm {
37/*! \brief namespace of arithmetic analysis. */
38namespace arith {
39//-------------------------------------------------------
40// Base integer analysis API.
41//
// We have multiple types of analyzers to do relaxed
// integer set analysis (bound analysis, modulo) as well as
// equivalence checking and simplification.
//
// Importantly, each analyzer may need results from
// another analyzer.
48//-------------------------------------------------------
49
50// Forward declare Analyzer
51class Analyzer;
52
53using tir::Var;
54
55enum DivMode {
56 /*! \brief Truncated division. */
57 kTruncDiv,
58 /*! \brief Floor division. */
59 kFloorDiv
60};
61
62/*!
63 * \brief Constant integer up and lower bound(inclusive).
64 * Useful for value bound analysis.
65 *
66 * set = [min_value, max_value]
67 */
68class ConstIntBoundNode : public Object {
69 public:
70 int64_t min_value;
71 int64_t max_value;
72
73 void VisitAttrs(tvm::AttrVisitor* v) {
74 v->Visit("min_value", &min_value);
75 v->Visit("max_value", &max_value);
76 }
77
78 bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const {
79 return equal(min_value, other->min_value) && equal(max_value, other->max_value);
80 }
81
82 /*! \brief Number to represent +inf */
83 static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
84 /*!
85 * \brief Number to represent -inf
86 * \note We can make use the of fact that -kPosInf == kNegInf in the project.
87 */
88 static const constexpr int64_t kNegInf = -kPosInf;
89
90 static constexpr const char* _type_key = "arith.ConstIntBound";
91 TVM_DECLARE_FINAL_OBJECT_INFO(ConstIntBoundNode, Object);
92};
93
94/*!
95 * \brief reference class to ConstIntBoundNode
96 * \sa ConstIntBoundNode
97 */
98class ConstIntBound : public ObjectRef {
99 public:
100 /*!
101 * \brief constructor by fields.
102 * \param min_value The mininum value.
103 * \param max_value The maximum value.
104 */
105 TVM_DLL ConstIntBound(int64_t min_value, int64_t max_value);
106
107 static const constexpr int64_t kPosInf = ConstIntBoundNode::kPosInf;
108 static const constexpr int64_t kNegInf = ConstIntBoundNode::kNegInf;
109 TVM_DEFINE_OBJECT_REF_METHODS(ConstIntBound, ObjectRef, ConstIntBoundNode);
110};
111
112/*!
113 * \brief Analyzer to get constant integer bound over expression.
114 */
115class ConstIntBoundAnalyzer {
116 public:
117 using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
118 /*!
119 * \brief analyze the expr
120 * \param expr The expression of interest.
121 * \return the result of the analysis.
122 */
123 TVM_DLL ConstIntBound operator()(const PrimExpr& expr) const;
124
125 /*!
126 * \brief analyze the expr with the intermediate memorized to avoid redundant computation
127 * \param expr The expression of interest.
128 * \param bound The lookup table to store the intermediate results
129 * \return the result of the analysis.
130 */
131 TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound);
132
133 /*!
134 * \brief Update constant int bound information of var.
135 *
136 * \param var The variable of interest.
137 * \param info The bound information.
138 * \param allow_override whether we allow override of existing information.
139 */
140 TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
141 /*!
142 * \brief Bind variable to a range.
143 *
144 * \param var The variable.
145 * \param range The range we bind to.
146 * \param allow_override Whether we allow overriding an existing var's range.
147 */
148 TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
149
150 private:
151 friend class Analyzer;
152 friend class ConstraintContext;
153 explicit ConstIntBoundAnalyzer(Analyzer* parent);
154 TVM_DLL ~ConstIntBoundAnalyzer();
155 /*!
156 * \brief Update the internal state to enter constraint.
157 * \param constraint A constraint expression.
158 *
159 * \return an exit function that must be called to cleanup the constraint can be nullptr.
160 */
161 std::function<void()> EnterConstraint(const PrimExpr& constraint);
162 struct Entry;
163 class Impl;
164 /*! \brief Internal impl */
165 Impl* impl_;
166};
167
168/*!
169 * \brief Range of a linear integer function.
170 * Use to do specify the possible index values.
171 *
172 * set = { coeff * x + base | x in Z }
173 *
174 * When coeff != 0, it can also be written as
175 * set = { n | n % coeff == base }
176 *
177 * This is useful to decide if the index is dividable by certain value.
178 * For example, if index = 0 + 4 x, then we know it can be divided by 4.
179 */
180class ModularSetNode : public Object {
181 public:
182 /*! \brief linear co-efficient */
183 int64_t coeff;
184 /*! \brief The base */
185 int64_t base;
186
187 void VisitAttrs(tvm::AttrVisitor* v) {
188 v->Visit("coeff", &coeff);
189 v->Visit("base", &base);
190 }
191
192 bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
193 return equal(coeff, other->coeff) && equal(base, other->base);
194 }
195
196 static constexpr const char* _type_key = "arith.ModularSet";
197 TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
198};
199
200/*!
201 * \brief reference of ModularSetNode
202 * \sa ModularSetNode
203 */
204class ModularSet : public ObjectRef {
205 public:
206 TVM_DLL ModularSet(int64_t coeff, int64_t base);
207
208 TVM_DEFINE_OBJECT_REF_METHODS(ModularSet, ObjectRef, ModularSetNode);
209};
210
211/*!
212 * \brief Analyzer to get modular information over expression.
213 */
214class ModularSetAnalyzer {
215 public:
216 /*!
217 * \brief analyze the expr
218 * \param expr The expression of interest.
219 * \return the result of the analysis.
220 */
221 TVM_DLL ModularSet operator()(const PrimExpr& expr);
222 /*!
223 * \brief Update constant int bound information of var.
224 *
225 * \param var The variable of interest.
226 * \param info The bound information.
227 * \param allow_override whether we allow override of existing information.
228 */
229 TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);
230
231 private:
232 friend class Analyzer;
233 friend class ConstraintContext;
234 explicit ModularSetAnalyzer(Analyzer* parent);
235 TVM_DLL ~ModularSetAnalyzer();
236 /*!
237 * \brief Update the internal state to enter constraint.
238 * \param constraint A constraint expression.
239 *
240 * \return an exit function that must be called to cleanup the constraint can be nullptr.
241 */
242 std::function<void()> EnterConstraint(const PrimExpr& constraint);
243 struct Entry;
244 class Impl;
245 /*! \brief Internal impl */
246 Impl* impl_;
247};
248
249/*!
250 * \brief Rewrite-rule based simplifier.
251 */
252class RewriteSimplifier {
253 public:
254 /*!
255 * \brief analyze the expr
256 * \param expr The expression of interest.
257 * \return the result of the analysis.
258 */
259 TVM_DLL PrimExpr operator()(const PrimExpr& expr);
260
261 /*!
262 * \brief Update binding of var to a new expression.
263 *
264 * \param var The variable of interest.
265 * \param new_expr
266 * \param allow_override Whether we allow override of existing information.
267 */
268 TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
269
270 /*!
271 * \brief Update the internal state to enter constraint.
272 * \param constraint A constraint expression.
273 *
274 * \return an exit function that must be called to cleanup the constraint can be nullptr.
275 */
276 TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
277
278 /*! \brief Flags to enable more computationally-intensive simplifications
279 *
280 * These simplifications may be required for specific schedules, but
281 * would impose too high a compile-time cost to enable by default.
282 * They can be enabled on an as-needed basis by calling
283 * `RewriteSimplifier::SetEnabledExtensions` prior to using
284 * `RewriteSimplifier::operator()`.
285 *
286 * Flags are defined as powers of two to allow future expansion. To
287 * enable multiple extensions, a user should pass a bitwise OR of the
288 * flags for each desired extension.
289 */
290 enum Extension {
291 // No extensions enabled
292 kNone = 0,
293
294 /* When simplifying an inequality, attempt to use scope-based knowns.
295 *
296 * Example:
297 * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
298 */
299 kTransitivelyProveInequalities = (1 << 0),
300
301 /* When simplifying a boolean expression, convert to an AND of ORs
302 * (conjunctive normal form).
303 *
304 * Example:
305 * (a && b) || c => (a || c) && (b || c)
306 */
307 kConvertBooleanToAndOfOrs = (1 << 1),
308
309 /* When simplifying a boolean AND or a boolean OR, simplify each
310 * branch under the assumption that the other branch does not
311 * already dominate the result. That is, simplify each branch of
312 * (A && B) under the assumption that the other branch is true,
313 * and simplify each branch of (A || B) under the assumption that
314 * the other branch is false.
315 *
316 * Example:
317 * (n < 10) && (n < 5) => (n < 10)
318 * (n < 10) || (n < 5) => (n < 5)
319 */
320 kApplyConstraintsToBooleanBranches = (1 << 2),
321 };
322
323 /*! \brief Enable an optional extension or extensions
324 *
325 * \param flags A bitwise OR of all optional extensions that should
326 * be enabled.
327 */
328 TVM_DLL void SetEnabledExtensions(Extension flags);
329
330 /*! \brief Return the currently enabled extensions */
331 TVM_DLL Extension GetEnabledExtensions() const;
332
333 private:
334 friend class Analyzer;
335 friend class ConstraintContext;
336 friend class CanonicalSimplifier;
337 explicit RewriteSimplifier(Analyzer* parent);
338 TVM_DLL ~RewriteSimplifier();
339 class Impl;
340 /*! \brief Internal impl */
341 Impl* impl_;
342};
343
344/*!
345 * \brief Canonical-form based simplifier.
346 */
347class CanonicalSimplifier {
348 public:
349 /*!
350 * \brief analyze the expr
351 * \param expr The expression of interest.
352 * \return the result of the analysis.
353 */
354 TVM_DLL PrimExpr operator()(const PrimExpr& expr);
355
356 /*!
357 * \brief Update binding of var to a new expression.
358 *
359 * \param var The variable of interest.
360 * \param new_expr
361 * \param allow_override whether we allow override of existing information.
362 */
363 TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);
364
365 private:
366 friend class Analyzer;
367 friend class ConstraintContext;
368 explicit CanonicalSimplifier(Analyzer* parent);
369 TVM_DLL ~CanonicalSimplifier();
370 class Impl;
371 /*! \brief Internal impl */
372 Impl* impl_;
373};
374
/*!
 * \brief Result of comparing two expressions against known facts.
 *
 * Each value is a bitmask over three primitive outcomes:
 * kEQ (bit 0), kLT (bit 1) and kGT (bit 2). A compound result is the
 * bitwise OR of the outcomes it permits, e.g. kLE == (kLT | kEQ),
 * kGE == (kGT | kEQ), kNE == (kLT | kGT). kUnknown has every bit set,
 * and kInconsistent has none, so the values can be combined with the
 * bitwise operators defined below.
 */
enum class CompareResult : int {
  kInconsistent = 0b000,
  kEQ = 0b001,
  kLT = 0b010,
  kLE = 0b011,
  kGT = 0b100,
  kGE = 0b101,
  kNE = 0b110,
  kUnknown = 0b111
};

/*! \brief Bitwise AND of two results: the outcomes permitted by both. */
inline constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) {
  return static_cast<CompareResult>(static_cast<int>(lhs) & static_cast<int>(rhs));
}
/*! \brief Bitwise OR of two results: the outcomes permitted by either. */
inline constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs) {
  return static_cast<CompareResult>(static_cast<int>(lhs) | static_cast<int>(rhs));
}
397
398/*!
399 * \brief Using previously specified knowns, compare the expressions provided
400 *
401 * Given known expressions [(a OP b), (b OP c), ..., (y OP z)], search
402 * for a known result for `(a OP z)`.
403 */
404class TransitiveComparisonAnalyzer {
405 public:
406 /* \brief Using previously specified knowns, compare the expressions provided
407 *
408 * \param lhs The left-hand side of the comparison
409 *
410 * \param rhs The right-hand side of the comparison
411 *
412 * \param propagate_inequalities If true, attempt to find a sequence
413 * of transitive inequalities that allow the lhs and rhs to be
414 * compared. If false, only use the known comparison that have been
415 * directly provided. Using `propagate_inequalities = false` is
416 * roughly equivalent to comparing against all known inequality
417 * expressions using `ExprDeepEqual`, but also allows for constant
418 * offsets on either side of the inequality.
419 *
420 * \return The most specific result that can be proven about the
421 * comparison. If nothing can be proven, returns kUnknown.
422 */
423 TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
424 bool propagate_inequalities = true);
425
426 /*! \brief Bind a variable as being equal to a known expression
427 *
428 * \param var The variable of interest.
429 * \param expr The bound expression
430 * \param allow_override Whether to allow override of existing information.
431 */
432 TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
433
434 /*! \brief Bind a variable as being within a specified range
435 *
436 * \param var The variable of interest.
437 * \param range The known range
438 * \param allow_override Whether to allow override of existing information.
439 */
440 TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);
441
442 /*!
443 * \brief Update the internal state to enter constraint.
444 * \param constraint A constraint expression.
445 *
446 * \return an exit function that must be called to cleanup the constraint can be nullptr.
447 */
448 TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
449
450 private:
451 friend class Analyzer;
452 friend class ConstraintContext;
453 TransitiveComparisonAnalyzer();
454 TVM_DLL ~TransitiveComparisonAnalyzer();
455 class Impl;
456 /*! \brief Internal impl */
457 std::unique_ptr<Impl> impl_;
458};
459
460/*!
461 * \brief Constraint context.
462 *
463 * \code
464 *
465 * Var("x");
466 * arith::Analyzer analyzer;
467 * {
468 * With<arith::ConstraintContext> scope(&analyzer, x % 3 == 0);
469 * ICHECK_EQ(analyzer.modular_set(x)->coeff, 3);
470 * }
471 * // constraint no longer in effect.
472 * ICHECK_NE(analyzer.modular_set(x)->coeff, 3);
473 *
474 * \endcode
475 */
476class ConstraintContext {
477 private:
478 // declare friend to enable with.
479 friend class With<ConstraintContext>;
480 /*!
481 * \brief Construct a constraint context.
482 * \param analyzer The analyzer.
483 * \param constraint The constraint to be applied.
484 */
485 ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
486 : analyzer_(analyzer), constraint_(constraint) {}
487 // enter the scope.
488 void EnterWithScope();
489 // exit the scope.
490 void ExitWithScope();
491 /*! \brief The analyzer */
492 Analyzer* analyzer_;
493 /*! \brief The constraint */
494 PrimExpr constraint_;
495 /*! \brief function to be called in recovery */
496 std::vector<std::function<void()>> recovery_functions_;
497};
498
499/*!
500 * \brief Integer set analyzer.
501 */
502class IntSetAnalyzer {
503 public:
504 /*!
505 * \brief Find a symbolic integer set that contains all possible values of
506 * expr given the domain of each variables.
507 *
508 * \param expr The expression of interest.
509 * \param dom_map The domain map to indicate which variable to relax.
510 * \return the result of the analysis.
511 */
512 TVM_DLL IntSet operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map);
513
514 /*!
515 * \brief Find a symbolic integer set that contains all possible
516 * values of expr given the domain of each variables, using
517 * the domain map defined by bound variables.
518 *
519 * \param expr The expression of interest.
520 * \return the result of the analysis.
521 */
522 TVM_DLL IntSet operator()(const PrimExpr& expr);
523
524 /*!
525 * \brief Update binding of var to a new expression.
526 *
527 * \param var The variable of interest.
528 * \param new_interval_set The set of allowed values for this var.
529 * \param allow_override whether we allow override of existing information.
530 */
531 TVM_DLL void Update(const Var& var, const IntSet& new_interval_set, bool allow_override = false);
532
533 /*!
534 * \brief Update binding of var to a new expression.
535 *
536 * \param var The variable of interest.
537 * \param new_range The range of allowed values for this var.
538 * \param allow_override whether we allow override of existing information.
539 */
540 TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);
541
542 std::function<void()> EnterConstraint(const PrimExpr& constraint);
543
544 private:
545 friend class Analyzer;
546 explicit IntSetAnalyzer(Analyzer* parent);
547 TVM_DLL ~IntSetAnalyzer();
548 class Impl;
549 /*! \brief Internal impl */
550 Impl* impl_;
551};
552
553/*!
554 * \brief Analyzer that contains bunch of sub-analyzers.
555 *
556 * Each sub-analyzer can make use of another sub-analyzer
557 * by weak reference of this.
558 *
559 * NOTE for sub-analyzer developers:
560 * If the analyzer uses memoization, we need to clear the internal
561 * cache when information about a Var has been overridden.
562 */
563class TVM_DLL Analyzer {
564 public:
565 /*
566 * Disable copy constructor.
567 */
568 Analyzer(const Analyzer&) = delete;
569 Analyzer& operator=(const Analyzer&) = delete;
570 /*! \brief sub-analyzer: const integer bound */
571 ConstIntBoundAnalyzer const_int_bound;
572 /*! \brief sub-analyzer: modular set */
573 ModularSetAnalyzer modular_set;
574 /*! \brief sub-analyzer rewrite simplify */
575 RewriteSimplifier rewrite_simplify;
576 /*! \brief sub-analyzer canonical simplify */
577 CanonicalSimplifier canonical_simplify;
578 /*! \brief sub-analyzer: int set */
579 IntSetAnalyzer int_set;
580 /*! \brief sub-analyzer transitive comparisons */
581 TransitiveComparisonAnalyzer transitive_comparisons;
582 /*! \brief constructor */
583 Analyzer();
584 /*!
585 * \brief Notify all the sub-analyzers that var
586 * is created and binded to expr.
587 *
588 * Each var can only be bound once.
589 *
590 * \param var The variable.
591 * \param expr The expression we bind to.
592 * \param allow_override Whether we allow overriding an existing var's
593 * expression. This option should not be used if there is any dependency
594 * between variables.
595 */
596 void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
597 /*!
598 * \brief Notify all the sub-analyzers that var
599 * is created and binded to a range.
600 *
601 * Each var can only be binded once.
602 *
603 * \param var The variable.
604 * \param range The range we bind to.
605 * \param allow_override Whether we allow overriding an existing var's
606 * expression. This option should not be used if there is any dependency
607 * between variables.
608 */
609 void Bind(const Var& var, const Range& range, bool allow_override = false);
610 /*!
611 * \brief Bind all the vars in the Map
612 *
613 * \param variables The {variable -> range} map.
614 * \param allow_override Whether we allow overriding an existing var's
615 * expression. This option should not be used if there is any dependency
616 * between variables.
617 */
618 void Bind(const Map<Var, Range>& variables, bool allow_override = false);
619 /*!
620 * \brief Whether can we prove expr >= val.
621
622 * Non-negative proof is very useful in integer analysis
623 * to lower divisions and mods given difference in trunc and ceil mode.
624 *
625 * \param expr The expression.
626 * \param lower_bound The lower bound.
627 * \return Whether we can prove it.
628 *
629 * \note Analyzer will call into sub-analyzers to get the result.
630 */
631 bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
632 /*!
633 * \brief Whether can we prove expr < val.
634
635 * Non-negative proof is very useful in integer analysis
636 * to lower divisions and mods given difference in trunc and ceil mode.
637 *
638 * \param expr The expression.
639 * \param upper_bound The upper bound.
640 * \return Whether we can prove it.
641 *
642 * \note Analyzer will call into sub-analyzers to get the result.
643 */
644 bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
645 /*!
646 * \brief Whether can we prove lhs == rhs.
647 *
648 * \param lhs The input lhs.
649 * \param rhs The input rhs.
650 * \return Whether we can prove lhs == rhs.
651 *
652 * \note Analyzer will call into sub-analyzers to get the result.
653 */
654 bool CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs);
655 /*!
656 * \brief Whether can we prove condition.
657 *
658 * \param cond The expression to be proved.
659 * \return The result.
660 *
661 * \note Analyzer will call into sub-analyzers to get the result.
662 */
663 bool CanProve(const PrimExpr& cond);
664 /*!
665 * \brief Simplify expr.
666 *
667 * \param expr The expression to be simplified.
668 * \param steps The simplification runs in the order of
669 * rewrite_simplify (step 1) -> canonical_simplify (step 2) ->
670 * rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ...
671 * param steps controls how many steps to run.
672 * Default is 2, i.e., rewrite_simplify + canonical_simplify.
673 * \return The result.
674 *
675 * \note Analyzer will call into sub-analyzers to get the result.
676 */
677 PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
678};
679
680} // namespace arith
681} // namespace tvm
682#endif // TVM_ARITH_ANALYZER_H_
683