1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file rewrite_simplify.h
22 * \brief Rewrite-rule based simplification.
23 */
24#ifndef TVM_ARITH_REWRITE_SIMPLIFY_H_
25#define TVM_ARITH_REWRITE_SIMPLIFY_H_
26
27#include <tvm/arith/analyzer.h>
28#include <tvm/tir/op.h>
29
30#include <unordered_map>
31#include <vector>
32
33#include "const_fold.h"
34#include "ir_mutator_with_analyzer.h"
35#include "pattern_match.h"
36
37namespace tvm {
38namespace arith {
39
40using namespace tir;
41
42/*!
43 * \brief Rewrite-based simplifier.
44 *
45 * This class can be inheritated for other simplifiers.
46 */
47class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
48 public:
49 using IRMutatorWithAnalyzer::VisitExpr_;
50
51 explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {}
52
53 void Update(const Var& var, const PrimExpr& info, bool override_info);
54 PrimExpr VisitExpr_(const AddNode* op) override;
55 PrimExpr VisitExpr_(const SubNode* op) override;
56 PrimExpr VisitExpr_(const MulNode* op) override;
57 PrimExpr VisitExpr_(const DivNode* op) override;
58 PrimExpr VisitExpr_(const ModNode* op) override;
59 PrimExpr VisitExpr_(const FloorDivNode* op) override;
60 PrimExpr VisitExpr_(const FloorModNode* op) override;
61 PrimExpr VisitExpr_(const MinNode* op) override;
62 PrimExpr VisitExpr_(const MaxNode* op) override;
63 PrimExpr VisitExpr_(const EQNode* op) override;
64 PrimExpr VisitExpr_(const NENode* op) override;
65 PrimExpr VisitExpr_(const LTNode* op) override;
66 PrimExpr VisitExpr_(const LENode* op) override;
67 PrimExpr VisitExpr_(const GTNode* op) override;
68 PrimExpr VisitExpr_(const GENode* op) override;
69 PrimExpr VisitExpr_(const AndNode* op) override;
70 PrimExpr VisitExpr_(const OrNode* op) override;
71 PrimExpr VisitExpr_(const NotNode* op) override;
72 PrimExpr VisitExpr_(const SelectNode* op) override;
73 PrimExpr VisitExpr_(const CallNode* op) override;
74 PrimExpr VisitExpr_(const VarNode* op) override;
75 PrimExpr VisitExpr_(const CastNode* op) override;
76 PrimExpr VisitExpr_(const LetNode* op) override;
77
78 std::function<void()> EnterConstraint(const PrimExpr& constraint);
79
80 /*! \brief Enable an optional extension or extensions
81 *
82 * \param flags A bitwise OR of all optional extensions that should
83 * be enabled.
84 */
85 void SetEnabledExtensions(Extension flags);
86
87 /*! \brief Return the currently enabled extensions */
88 Extension GetEnabledExtensions() const;
89
90 protected:
91 // counter to record recursive rewrite depth.
92 int recur_depth_{0};
93 // internal variable map
94 std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
95
96 std::vector<PrimExpr> literal_constraints_;
97
98 // Optionally enabled extensions
99 Extension enabled_extensions_{kNone};
100
101 /*! Whether the simplifier is current
102 */
103 bool recursively_visiting_boolean_{false};
104
105 // maximum number of recursion allowed during a single pass.
106 static const constexpr int kMaxRecurDepth = 5;
107
108 /*!
109 * \brief try to compare x against val.
110 * \param x The expression to be evaluated.
111 * \param val The constant value.
112 * \return comparison result.
113 */
114 CompareResult TryCompare(const PrimExpr& x, int64_t val);
115
116 /*! Try to compare x against y
117 *
118 * \param x The lhs of the comparison
119 * \param y The rhs of the comparison
120 * \return comparison result.
121 */
122 CompareResult TryCompare(const PrimExpr& x, const PrimExpr& y);
123
124 /*!
125 * \brief Internal function to check whether or not to inline let.
126 * \param op The let expr.
127 * \return The inline decision.
128 */
129 bool CanInlineLet(const LetNode* op);
130
131 /*! \brief Internal function to apply constraints
132 *
133 * Tests whether the expression is known to be true or false based
134 * on existing constraints. If the expression or its negation
135 * matches a constraint, return the boolean it should be replaced
136 * with. Otherwise, return false.
137 */
138 Optional<PrimExpr> TryMatchLiteralConstraint(const PrimExpr& expr) const;
139
140 /*! \brief Rewrite rules for Less Than comparisons
141 *
142 * These are separate from the VisitExpr_(const LTNode*) method, as
143 * they may required from rewrites of LT or LE.
144 */
145 PrimExpr ApplyRewriteRules(LT node);
146
147 /*! \brief Rewrite rules for Equal comparisons
148 *
149 * These are separate from the VisitExpr_(const EQNode*) method, as
150 * they may required from rewrites of LE or NE.
151 */
152 PrimExpr ApplyRewriteRules(EQ node);
153
154 /*! \brief Rewrite rules for Equal comparisons
155 *
156 * These are separate from the VisitExpr_(const EQNode*) method, as
157 * they may required from rewrites of LT, LE, or NE.
158 */
159 PrimExpr ApplyRewriteRules(Not node);
160
161 private:
162 CompareResult TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y);
163 CompareResult TryCompareUsingConstIntBounds(const PrimExpr& x, const PrimExpr y);
164
165 // Whether x >= val
166 bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
167 return analyzer_->CanProveGreaterEqual(x, val);
168 }
169 // Whether x < val
170 bool CanProveLess(const PrimExpr& x, int64_t val) { return analyzer_->CanProveLess(x, val); }
171 // Whether x == val
172 bool CanProveEqual(const PrimExpr& x, int64_t val) {
173 // TODO(tqchen) refer back to super-analyzer.
174 return TryCompare(x, val) == CompareResult::kEQ;
175 }
176
177 // Recursive rewrite x
178 // we limit maximum depth of recursive rewrite allowed to
179 // avoid infinite loop
180 PrimExpr RecursiveRewrite(const PrimExpr& x) {
181 if (recur_depth_ >= kMaxRecurDepth) return x;
182 ++recur_depth_;
183 PrimExpr res = this->VisitExpr(x);
184 --recur_depth_;
185 return res;
186 }
187
188 template <typename TA>
189 PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) {
190 return PConstWithTypeLike<TA>(pattern.derived(), 0);
191 }
192
193 template <typename TA>
194 PConstWithTypeLike<TA> OneWithTypeLike(const Pattern<TA>& pattern) {
195 return PConstWithTypeLike<TA>(pattern.derived(), 1);
196 }
197};
198
199} // namespace arith
200} // namespace tvm
201#endif // TVM_ARITH_REWRITE_SIMPLIFY_H_
202