1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file rewrite_simplify.h |
22 | * \brief Rewrite-rule based simplification. |
23 | */ |
24 | #ifndef TVM_ARITH_REWRITE_SIMPLIFY_H_ |
25 | #define TVM_ARITH_REWRITE_SIMPLIFY_H_ |
26 | |
27 | #include <tvm/arith/analyzer.h> |
28 | #include <tvm/tir/op.h> |
29 | |
30 | #include <unordered_map> |
31 | #include <vector> |
32 | |
33 | #include "const_fold.h" |
34 | #include "ir_mutator_with_analyzer.h" |
35 | #include "pattern_match.h" |
36 | |
37 | namespace tvm { |
38 | namespace arith { |
39 | |
40 | using namespace tir; |
41 | |
42 | /*! |
43 | * \brief Rewrite-based simplifier. |
44 | * |
45 | * This class can be inheritated for other simplifiers. |
46 | */ |
47 | class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { |
48 | public: |
49 | using IRMutatorWithAnalyzer::VisitExpr_; |
50 | |
51 | explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {} |
52 | |
53 | void Update(const Var& var, const PrimExpr& info, bool override_info); |
54 | PrimExpr VisitExpr_(const AddNode* op) override; |
55 | PrimExpr VisitExpr_(const SubNode* op) override; |
56 | PrimExpr VisitExpr_(const MulNode* op) override; |
57 | PrimExpr VisitExpr_(const DivNode* op) override; |
58 | PrimExpr VisitExpr_(const ModNode* op) override; |
59 | PrimExpr VisitExpr_(const FloorDivNode* op) override; |
60 | PrimExpr VisitExpr_(const FloorModNode* op) override; |
61 | PrimExpr VisitExpr_(const MinNode* op) override; |
62 | PrimExpr VisitExpr_(const MaxNode* op) override; |
63 | PrimExpr VisitExpr_(const EQNode* op) override; |
64 | PrimExpr VisitExpr_(const NENode* op) override; |
65 | PrimExpr VisitExpr_(const LTNode* op) override; |
66 | PrimExpr VisitExpr_(const LENode* op) override; |
67 | PrimExpr VisitExpr_(const GTNode* op) override; |
68 | PrimExpr VisitExpr_(const GENode* op) override; |
69 | PrimExpr VisitExpr_(const AndNode* op) override; |
70 | PrimExpr VisitExpr_(const OrNode* op) override; |
71 | PrimExpr VisitExpr_(const NotNode* op) override; |
72 | PrimExpr VisitExpr_(const SelectNode* op) override; |
73 | PrimExpr VisitExpr_(const CallNode* op) override; |
74 | PrimExpr VisitExpr_(const VarNode* op) override; |
75 | PrimExpr VisitExpr_(const CastNode* op) override; |
76 | PrimExpr VisitExpr_(const LetNode* op) override; |
77 | |
78 | std::function<void()> EnterConstraint(const PrimExpr& constraint); |
79 | |
80 | /*! \brief Enable an optional extension or extensions |
81 | * |
82 | * \param flags A bitwise OR of all optional extensions that should |
83 | * be enabled. |
84 | */ |
85 | void SetEnabledExtensions(Extension flags); |
86 | |
87 | /*! \brief Return the currently enabled extensions */ |
88 | Extension GetEnabledExtensions() const; |
89 | |
90 | protected: |
91 | // counter to record recursive rewrite depth. |
92 | int recur_depth_{0}; |
93 | // internal variable map |
94 | std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_; |
95 | |
96 | std::vector<PrimExpr> literal_constraints_; |
97 | |
98 | // Optionally enabled extensions |
99 | Extension enabled_extensions_{kNone}; |
100 | |
101 | /*! Whether the simplifier is current |
102 | */ |
103 | bool recursively_visiting_boolean_{false}; |
104 | |
105 | // maximum number of recursion allowed during a single pass. |
106 | static const constexpr int kMaxRecurDepth = 5; |
107 | |
108 | /*! |
109 | * \brief try to compare x against val. |
110 | * \param x The expression to be evaluated. |
111 | * \param val The constant value. |
112 | * \return comparison result. |
113 | */ |
114 | CompareResult TryCompare(const PrimExpr& x, int64_t val); |
115 | |
116 | /*! Try to compare x against y |
117 | * |
118 | * \param x The lhs of the comparison |
119 | * \param y The rhs of the comparison |
120 | * \return comparison result. |
121 | */ |
122 | CompareResult TryCompare(const PrimExpr& x, const PrimExpr& y); |
123 | |
124 | /*! |
125 | * \brief Internal function to check whether or not to inline let. |
126 | * \param op The let expr. |
127 | * \return The inline decision. |
128 | */ |
129 | bool CanInlineLet(const LetNode* op); |
130 | |
131 | /*! \brief Internal function to apply constraints |
132 | * |
133 | * Tests whether the expression is known to be true or false based |
134 | * on existing constraints. If the expression or its negation |
135 | * matches a constraint, return the boolean it should be replaced |
136 | * with. Otherwise, return false. |
137 | */ |
138 | Optional<PrimExpr> TryMatchLiteralConstraint(const PrimExpr& expr) const; |
139 | |
140 | /*! \brief Rewrite rules for Less Than comparisons |
141 | * |
142 | * These are separate from the VisitExpr_(const LTNode*) method, as |
143 | * they may required from rewrites of LT or LE. |
144 | */ |
145 | PrimExpr ApplyRewriteRules(LT node); |
146 | |
147 | /*! \brief Rewrite rules for Equal comparisons |
148 | * |
149 | * These are separate from the VisitExpr_(const EQNode*) method, as |
150 | * they may required from rewrites of LE or NE. |
151 | */ |
152 | PrimExpr ApplyRewriteRules(EQ node); |
153 | |
154 | /*! \brief Rewrite rules for Equal comparisons |
155 | * |
156 | * These are separate from the VisitExpr_(const EQNode*) method, as |
157 | * they may required from rewrites of LT, LE, or NE. |
158 | */ |
159 | PrimExpr ApplyRewriteRules(Not node); |
160 | |
161 | private: |
162 | CompareResult TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y); |
163 | CompareResult TryCompareUsingConstIntBounds(const PrimExpr& x, const PrimExpr y); |
164 | |
165 | // Whether x >= val |
166 | bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) { |
167 | return analyzer_->CanProveGreaterEqual(x, val); |
168 | } |
169 | // Whether x < val |
170 | bool CanProveLess(const PrimExpr& x, int64_t val) { return analyzer_->CanProveLess(x, val); } |
171 | // Whether x == val |
172 | bool CanProveEqual(const PrimExpr& x, int64_t val) { |
173 | // TODO(tqchen) refer back to super-analyzer. |
174 | return TryCompare(x, val) == CompareResult::kEQ; |
175 | } |
176 | |
177 | // Recursive rewrite x |
178 | // we limit maximum depth of recursive rewrite allowed to |
179 | // avoid infinite loop |
180 | PrimExpr RecursiveRewrite(const PrimExpr& x) { |
181 | if (recur_depth_ >= kMaxRecurDepth) return x; |
182 | ++recur_depth_; |
183 | PrimExpr res = this->VisitExpr(x); |
184 | --recur_depth_; |
185 | return res; |
186 | } |
187 | |
188 | template <typename TA> |
189 | PConstWithTypeLike<TA> ZeroWithTypeLike(const Pattern<TA>& pattern) { |
190 | return PConstWithTypeLike<TA>(pattern.derived(), 0); |
191 | } |
192 | |
193 | template <typename TA> |
194 | PConstWithTypeLike<TA> OneWithTypeLike(const Pattern<TA>& pattern) { |
195 | return PConstWithTypeLike<TA>(pattern.derived(), 1); |
196 | } |
197 | }; |
198 | |
199 | } // namespace arith |
200 | } // namespace tvm |
201 | #endif // TVM_ARITH_REWRITE_SIMPLIFY_H_ |
202 | |