1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file detect_linear_equation.cc
 * \brief Utilities to detect linear equations and clip bounds (per-variable
 *        min/max constraints) in expressions.
 */
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/expr_functor.h> |
29 | #include <tvm/tir/op.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | |
32 | namespace tvm { |
33 | namespace arith { |
34 | |
35 | using namespace tir; |
36 | |
37 | // Linear equation, the components can be undefined. |
38 | struct LinearEqEntry { |
39 | PrimExpr base; |
40 | PrimExpr coeff; |
41 | }; |
42 | |
43 | struct IntervalEntry { |
44 | PrimExpr min_value; |
45 | PrimExpr max_value; |
46 | }; |
47 | |
48 | class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const PrimExpr&)> { |
49 | public: |
50 | explicit LinearEqDetector(Var var) : var_(var) {} |
51 | |
52 | bool Detect(const PrimExpr& e, LinearEqEntry* ret) { |
53 | *ret = VisitExpr(e, e); |
54 | if (fail_) return false; |
55 | if (!ret->base.defined()) { |
56 | ret->base = make_zero(var_.dtype()); |
57 | } |
58 | if (!ret->coeff.defined()) { |
59 | ret->coeff = make_zero(var_.dtype()); |
60 | } |
61 | return true; |
62 | } |
63 | |
64 | LinearEqEntry VisitExpr_(const AddNode* op, const PrimExpr& e) final { |
65 | if (fail_) return LinearEqEntry(); |
66 | LinearEqEntry a = VisitExpr(op->a, op->a); |
67 | LinearEqEntry b = VisitExpr(op->b, op->b); |
68 | LinearEqEntry ret; |
69 | ret.base = AddCombine(a.base, b.base); |
70 | ret.coeff = AddCombine(a.coeff, b.coeff); |
71 | return ret; |
72 | } |
73 | |
74 | LinearEqEntry VisitExpr_(const SubNode* op, const PrimExpr& e) final { |
75 | if (fail_) return LinearEqEntry(); |
76 | LinearEqEntry a = VisitExpr(op->a, op->a); |
77 | LinearEqEntry b = VisitExpr(op->b, op->b); |
78 | LinearEqEntry ret; |
79 | ret.base = SubCombine(a.base, b.base); |
80 | ret.coeff = SubCombine(a.coeff, b.coeff); |
81 | return ret; |
82 | } |
83 | |
84 | LinearEqEntry VisitExpr_(const MulNode* op, const PrimExpr& e) final { |
85 | if (fail_) return LinearEqEntry(); |
86 | LinearEqEntry a = VisitExpr(op->a, op->a); |
87 | LinearEqEntry b = VisitExpr(op->b, op->b); |
88 | if (a.coeff.defined()) { |
89 | std::swap(a, b); |
90 | } |
91 | if (a.coeff.defined()) { |
92 | fail_ = true; |
93 | return LinearEqEntry(); |
94 | } |
95 | LinearEqEntry ret; |
96 | ret.base = MulCombine(a.base, b.base); |
97 | ret.coeff = MulCombine(a.base, b.coeff); |
98 | return ret; |
99 | } |
100 | LinearEqEntry VisitExpr_(const VarNode* op, const PrimExpr& e) final { |
101 | LinearEqEntry ret; |
102 | if (op == var_.get()) { |
103 | ret.coeff = make_const(op->dtype, 1); |
104 | } else { |
105 | ret.base = e; |
106 | } |
107 | return ret; |
108 | } |
109 | LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final { |
110 | if (fail_) return LinearEqEntry(); |
111 | if (UsesVar(e, [this](const VarNode* var) { return var == var_.get(); })) { |
112 | fail_ = true; |
113 | return LinearEqEntry(); |
114 | } else { |
115 | LinearEqEntry ret; |
116 | ret.base = e; |
117 | return ret; |
118 | } |
119 | } |
120 | |
121 | private: |
122 | Var var_; |
123 | bool fail_{false}; |
124 | // Combine by add |
125 | PrimExpr AddCombine(PrimExpr a, PrimExpr b) { |
126 | if (!a.defined()) return b; |
127 | if (!b.defined()) return a; |
128 | return a + b; |
129 | } |
130 | PrimExpr SubCombine(PrimExpr a, PrimExpr b) { |
131 | // Check b first in case they are both undefined |
132 | if (!b.defined()) return a; |
133 | if (!a.defined()) return -b; |
134 | return a - b; |
135 | } |
136 | PrimExpr MulCombine(PrimExpr a, PrimExpr b) { |
137 | if (!a.defined()) return a; |
138 | if (!b.defined()) return b; |
139 | return a * b; |
140 | } |
141 | }; |
142 | |
143 | Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars) { |
144 | PrimExpr base = e; |
145 | Array<PrimExpr> coeff; |
146 | |
147 | for (Var v : vars) { |
148 | LinearEqEntry ret; |
149 | if (!LinearEqDetector(v).Detect(base, &ret)) { |
150 | return Array<PrimExpr>(); |
151 | } |
152 | coeff.push_back(ret.coeff); |
153 | base = std::move(ret.base); |
154 | } |
155 | |
156 | std::unordered_set<const VarNode*> vset; |
157 | auto vset_contains = [&](const VarNode* node) { return vset.count(node) != 0; }; |
158 | |
159 | for (size_t i = vars.size(); i > 1; --i) { |
160 | vset.insert(vars[i - 1].get()); |
161 | // The previous coeff contains the variable |
162 | if (UsesVar(coeff[i - 2], vset_contains)) { |
163 | return Array<PrimExpr>(); |
164 | } |
165 | } |
166 | coeff.push_back(base); |
167 | return coeff; |
168 | } |
169 | |
170 | // Detect clip condition as min max value |
171 | bool DetectClipBound(const PrimExpr& cond, |
172 | std::unordered_map<const VarNode*, IntervalEntry>* bmap) { |
173 | int flag = 0; |
174 | Var var; |
175 | auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { |
176 | if (const VarNode* v = n.as<VarNode>()) { |
177 | if (bmap->count(v)) { |
178 | if (flag == 0) { |
179 | var = Downcast<Var>(n); |
180 | flag = 1; |
181 | } else if (flag == 1) { |
182 | if (!var.same_as(n)) { |
183 | flag = -1; |
184 | } |
185 | } |
186 | } |
187 | } |
188 | }; |
189 | PostOrderVisit(cond, fvisit); |
190 | if (flag != 1) return false; |
191 | // canonical form: exp >= 0 |
192 | bool is_eq = false; |
193 | PrimExpr canonical; |
194 | if (const LTNode* op = cond.as<LTNode>()) { |
195 | if (!op->a.dtype().is_int()) return false; |
196 | canonical = op->b - op->a - make_const(op->a.dtype(), 1); |
197 | } else if (const LENode* op = cond.as<LENode>()) { |
198 | if (!op->a.dtype().is_int()) return false; |
199 | canonical = op->b - op->a; |
200 | } else if (const GTNode* op = cond.as<GTNode>()) { |
201 | if (!op->a.dtype().is_int()) return false; |
202 | canonical = op->a - op->b - make_const(op->a.dtype(), 1); |
203 | } else if (const GENode* op = cond.as<GENode>()) { |
204 | if (!op->a.dtype().is_int()) return false; |
205 | canonical = op->a - op->b; |
206 | } else if (const EQNode* op = cond.as<EQNode>()) { |
207 | if (!op->a.dtype().is_int()) return false; |
208 | canonical = op->a - op->b; |
209 | is_eq = true; |
210 | } else { |
211 | return false; |
212 | } |
213 | LinearEqEntry ret; |
214 | Analyzer analyzer; |
215 | if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; |
216 | ret.coeff = analyzer.Simplify(ret.coeff); |
217 | IntervalEntry& p = (*bmap)[var.get()]; |
218 | |
219 | Optional<PrimExpr> min_value; |
220 | Optional<PrimExpr> max_value; |
221 | if (is_const_int(ret.coeff, 1)) { |
222 | // var + shift >=0 -> var >= -shift |
223 | min_value = -ret.base; |
224 | if (is_eq) { |
225 | max_value = min_value; |
226 | } |
227 | } else if (is_const_int(ret.coeff, -1)) { |
228 | // -var + shift >=0 -> var <= shift |
229 | max_value = ret.base; |
230 | if (is_eq) { |
231 | min_value = max_value; |
232 | } |
233 | } |
234 | if (!min_value.defined() && !max_value.defined()) { |
235 | return false; |
236 | } |
237 | if (min_value.defined()) { |
238 | if (p.min_value.defined()) { |
239 | p.min_value = max(p.min_value, min_value.value()); |
240 | } else { |
241 | p.min_value = min_value.value(); |
242 | } |
243 | } |
244 | if (max_value.defined()) { |
245 | if (p.max_value.defined()) { |
246 | p.max_value = min(p.max_value, max_value.value()); |
247 | } else { |
248 | p.max_value = max_value.value(); |
249 | } |
250 | } |
251 | return true; |
252 | } |
253 | |
254 | template <typename OP> |
255 | void SplitCommExpr(const PrimExpr& e, std::vector<PrimExpr>* ret) { |
256 | if (const OP* op = e.as<OP>()) { |
257 | SplitCommExpr<OP>(op->a, ret); |
258 | SplitCommExpr<OP>(op->b, ret); |
259 | } else { |
260 | ret->push_back(e); |
261 | } |
262 | } |
263 | |
264 | // Detect the lower and upper bound from the expression. |
265 | // e must be connected by and. |
266 | Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) { |
267 | std::vector<PrimExpr> splits; |
268 | Analyzer analyzer; |
269 | SplitCommExpr<tir::AndNode>(analyzer.Simplify(e), &splits); |
270 | std::unordered_map<const VarNode*, IntervalEntry> rmap; |
271 | for (Var v : vars) { |
272 | rmap[v.get()] = IntervalEntry(); |
273 | } |
274 | for (PrimExpr cond : splits) { |
275 | if (!DetectClipBound(cond, &rmap)) return Array<PrimExpr>(); |
276 | } |
277 | Array<PrimExpr> ret; |
278 | for (Var v : vars) { |
279 | IntervalEntry e = rmap[v.get()]; |
280 | if (e.min_value.defined()) { |
281 | e.min_value = analyzer.Simplify(e.min_value); |
282 | } |
283 | if (e.max_value.defined()) { |
284 | e.max_value = analyzer.Simplify(e.max_value); |
285 | } |
286 | ret.push_back(e.min_value); |
287 | ret.push_back(e.max_value); |
288 | } |
289 | return ret; |
290 | } |
291 | |
292 | TVM_REGISTER_GLOBAL("arith.DetectLinearEquation" ).set_body_typed(DetectLinearEquation); |
293 | |
294 | TVM_REGISTER_GLOBAL("arith.DetectClipBound" ) |
295 | .set_body_typed([](const PrimExpr& e, const Array<Var>& vars) { |
296 | return DetectClipBound(e, vars); |
297 | }); |
298 | } // namespace arith |
299 | } // namespace tvm |
300 | |