1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file detect_linear_equation.cc
 * \brief Utilities to detect linear equations and clip bounds (min/max constraints) in expressions.
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/expr.h>
28#include <tvm/tir/expr_functor.h>
29#include <tvm/tir/op.h>
30#include <tvm/tir/stmt_functor.h>
31
32namespace tvm {
33namespace arith {
34
35using namespace tir;
36
37// Linear equation, the components can be undefined.
38struct LinearEqEntry {
39 PrimExpr base;
40 PrimExpr coeff;
41};
42
43struct IntervalEntry {
44 PrimExpr min_value;
45 PrimExpr max_value;
46};
47
48class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const PrimExpr&)> {
49 public:
50 explicit LinearEqDetector(Var var) : var_(var) {}
51
52 bool Detect(const PrimExpr& e, LinearEqEntry* ret) {
53 *ret = VisitExpr(e, e);
54 if (fail_) return false;
55 if (!ret->base.defined()) {
56 ret->base = make_zero(var_.dtype());
57 }
58 if (!ret->coeff.defined()) {
59 ret->coeff = make_zero(var_.dtype());
60 }
61 return true;
62 }
63
64 LinearEqEntry VisitExpr_(const AddNode* op, const PrimExpr& e) final {
65 if (fail_) return LinearEqEntry();
66 LinearEqEntry a = VisitExpr(op->a, op->a);
67 LinearEqEntry b = VisitExpr(op->b, op->b);
68 LinearEqEntry ret;
69 ret.base = AddCombine(a.base, b.base);
70 ret.coeff = AddCombine(a.coeff, b.coeff);
71 return ret;
72 }
73
74 LinearEqEntry VisitExpr_(const SubNode* op, const PrimExpr& e) final {
75 if (fail_) return LinearEqEntry();
76 LinearEqEntry a = VisitExpr(op->a, op->a);
77 LinearEqEntry b = VisitExpr(op->b, op->b);
78 LinearEqEntry ret;
79 ret.base = SubCombine(a.base, b.base);
80 ret.coeff = SubCombine(a.coeff, b.coeff);
81 return ret;
82 }
83
84 LinearEqEntry VisitExpr_(const MulNode* op, const PrimExpr& e) final {
85 if (fail_) return LinearEqEntry();
86 LinearEqEntry a = VisitExpr(op->a, op->a);
87 LinearEqEntry b = VisitExpr(op->b, op->b);
88 if (a.coeff.defined()) {
89 std::swap(a, b);
90 }
91 if (a.coeff.defined()) {
92 fail_ = true;
93 return LinearEqEntry();
94 }
95 LinearEqEntry ret;
96 ret.base = MulCombine(a.base, b.base);
97 ret.coeff = MulCombine(a.base, b.coeff);
98 return ret;
99 }
100 LinearEqEntry VisitExpr_(const VarNode* op, const PrimExpr& e) final {
101 LinearEqEntry ret;
102 if (op == var_.get()) {
103 ret.coeff = make_const(op->dtype, 1);
104 } else {
105 ret.base = e;
106 }
107 return ret;
108 }
109 LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
110 if (fail_) return LinearEqEntry();
111 if (UsesVar(e, [this](const VarNode* var) { return var == var_.get(); })) {
112 fail_ = true;
113 return LinearEqEntry();
114 } else {
115 LinearEqEntry ret;
116 ret.base = e;
117 return ret;
118 }
119 }
120
121 private:
122 Var var_;
123 bool fail_{false};
124 // Combine by add
125 PrimExpr AddCombine(PrimExpr a, PrimExpr b) {
126 if (!a.defined()) return b;
127 if (!b.defined()) return a;
128 return a + b;
129 }
130 PrimExpr SubCombine(PrimExpr a, PrimExpr b) {
131 // Check b first in case they are both undefined
132 if (!b.defined()) return a;
133 if (!a.defined()) return -b;
134 return a - b;
135 }
136 PrimExpr MulCombine(PrimExpr a, PrimExpr b) {
137 if (!a.defined()) return a;
138 if (!b.defined()) return b;
139 return a * b;
140 }
141};
142
143Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars) {
144 PrimExpr base = e;
145 Array<PrimExpr> coeff;
146
147 for (Var v : vars) {
148 LinearEqEntry ret;
149 if (!LinearEqDetector(v).Detect(base, &ret)) {
150 return Array<PrimExpr>();
151 }
152 coeff.push_back(ret.coeff);
153 base = std::move(ret.base);
154 }
155
156 std::unordered_set<const VarNode*> vset;
157 auto vset_contains = [&](const VarNode* node) { return vset.count(node) != 0; };
158
159 for (size_t i = vars.size(); i > 1; --i) {
160 vset.insert(vars[i - 1].get());
161 // The previous coeff contains the variable
162 if (UsesVar(coeff[i - 2], vset_contains)) {
163 return Array<PrimExpr>();
164 }
165 }
166 coeff.push_back(base);
167 return coeff;
168}
169
170// Detect clip condition as min max value
171bool DetectClipBound(const PrimExpr& cond,
172 std::unordered_map<const VarNode*, IntervalEntry>* bmap) {
173 int flag = 0;
174 Var var;
175 auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) {
176 if (const VarNode* v = n.as<VarNode>()) {
177 if (bmap->count(v)) {
178 if (flag == 0) {
179 var = Downcast<Var>(n);
180 flag = 1;
181 } else if (flag == 1) {
182 if (!var.same_as(n)) {
183 flag = -1;
184 }
185 }
186 }
187 }
188 };
189 PostOrderVisit(cond, fvisit);
190 if (flag != 1) return false;
191 // canonical form: exp >= 0
192 bool is_eq = false;
193 PrimExpr canonical;
194 if (const LTNode* op = cond.as<LTNode>()) {
195 if (!op->a.dtype().is_int()) return false;
196 canonical = op->b - op->a - make_const(op->a.dtype(), 1);
197 } else if (const LENode* op = cond.as<LENode>()) {
198 if (!op->a.dtype().is_int()) return false;
199 canonical = op->b - op->a;
200 } else if (const GTNode* op = cond.as<GTNode>()) {
201 if (!op->a.dtype().is_int()) return false;
202 canonical = op->a - op->b - make_const(op->a.dtype(), 1);
203 } else if (const GENode* op = cond.as<GENode>()) {
204 if (!op->a.dtype().is_int()) return false;
205 canonical = op->a - op->b;
206 } else if (const EQNode* op = cond.as<EQNode>()) {
207 if (!op->a.dtype().is_int()) return false;
208 canonical = op->a - op->b;
209 is_eq = true;
210 } else {
211 return false;
212 }
213 LinearEqEntry ret;
214 Analyzer analyzer;
215 if (!LinearEqDetector(var).Detect(canonical, &ret)) return false;
216 ret.coeff = analyzer.Simplify(ret.coeff);
217 IntervalEntry& p = (*bmap)[var.get()];
218
219 Optional<PrimExpr> min_value;
220 Optional<PrimExpr> max_value;
221 if (is_const_int(ret.coeff, 1)) {
222 // var + shift >=0 -> var >= -shift
223 min_value = -ret.base;
224 if (is_eq) {
225 max_value = min_value;
226 }
227 } else if (is_const_int(ret.coeff, -1)) {
228 // -var + shift >=0 -> var <= shift
229 max_value = ret.base;
230 if (is_eq) {
231 min_value = max_value;
232 }
233 }
234 if (!min_value.defined() && !max_value.defined()) {
235 return false;
236 }
237 if (min_value.defined()) {
238 if (p.min_value.defined()) {
239 p.min_value = max(p.min_value, min_value.value());
240 } else {
241 p.min_value = min_value.value();
242 }
243 }
244 if (max_value.defined()) {
245 if (p.max_value.defined()) {
246 p.max_value = min(p.max_value, max_value.value());
247 } else {
248 p.max_value = max_value.value();
249 }
250 }
251 return true;
252}
253
254template <typename OP>
255void SplitCommExpr(const PrimExpr& e, std::vector<PrimExpr>* ret) {
256 if (const OP* op = e.as<OP>()) {
257 SplitCommExpr<OP>(op->a, ret);
258 SplitCommExpr<OP>(op->b, ret);
259 } else {
260 ret->push_back(e);
261 }
262}
263
264// Detect the lower and upper bound from the expression.
265// e must be connected by and.
266Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
267 std::vector<PrimExpr> splits;
268 Analyzer analyzer;
269 SplitCommExpr<tir::AndNode>(analyzer.Simplify(e), &splits);
270 std::unordered_map<const VarNode*, IntervalEntry> rmap;
271 for (Var v : vars) {
272 rmap[v.get()] = IntervalEntry();
273 }
274 for (PrimExpr cond : splits) {
275 if (!DetectClipBound(cond, &rmap)) return Array<PrimExpr>();
276 }
277 Array<PrimExpr> ret;
278 for (Var v : vars) {
279 IntervalEntry e = rmap[v.get()];
280 if (e.min_value.defined()) {
281 e.min_value = analyzer.Simplify(e.min_value);
282 }
283 if (e.max_value.defined()) {
284 e.max_value = analyzer.Simplify(e.max_value);
285 }
286 ret.push_back(e.min_value);
287 ret.push_back(e.max_value);
288 }
289 return ret;
290}
291
292TVM_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation);
293
294TVM_REGISTER_GLOBAL("arith.DetectClipBound")
295 .set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
296 return DetectClipBound(e, vars);
297 });
298} // namespace arith
299} // namespace tvm
300