1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file bound_deducer.cc |
22 | * \brief Utility to deduce bound of expression |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/expr.h> |
27 | #include <tvm/tir/expr_functor.h> |
28 | |
29 | #include <unordered_map> |
30 | #include <unordered_set> |
31 | |
32 | #include "interval_set.h" |
33 | |
34 | namespace tvm { |
35 | namespace arith { |
36 | |
37 | using namespace tir; |
38 | |
39 | // a visitor to find the path to the target variable |
40 | // from a expression. |
41 | class VariablePathFinder : public ExprVisitor { |
42 | public: |
43 | explicit VariablePathFinder(PrimExpr target) : target_(target) {} |
44 | |
45 | void VisitExpr(const PrimExpr& node) final { |
46 | if (visited_.count(node.get()) != 0) return; |
47 | visited_.insert(node.get()); |
48 | |
49 | if (!found_) path_.push_back(node.get()); |
50 | if (node.same_as(target_)) found_ = true; |
51 | ExprVisitor::VisitExpr(node); |
52 | if (!found_) path_.pop_back(); |
53 | } |
54 | |
55 | std::vector<const Object*> path_; |
56 | |
57 | private: |
58 | bool found_{false}; |
59 | PrimExpr target_; |
60 | std::unordered_set<const Object*> visited_; |
61 | }; |
62 | |
63 | // get the path to the variable, |
64 | // return empty vector to represent failure |
65 | std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) { |
66 | VariablePathFinder v(target); |
67 | v(expr); |
68 | return v.path_; |
69 | } |
70 | |
// Direction of the normalized constraint `expr <op> result` handled by
// BoundDeducer.
enum CompareOp {
  kGreater,  // expr >= result
  kLess,     // expr <= result
  kEqual,    // expr == result
};
72 | |
73 | // a visitor to deduce the bound of a variable from a expression |
74 | class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> { |
75 | public: |
76 | friend class BoundDeduceInputChecker; |
77 | friend class Converter; |
78 | BoundDeducer(PrimExpr target, PrimExpr expr, |
79 | const std::unordered_map<const VarNode*, IntSet>& hint_map, |
80 | const std::unordered_map<const VarNode*, IntSet>& relax_map) |
81 | : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} |
82 | |
83 | void Deduce(); |
84 | |
85 | void VisitExpr(const PrimExpr& e) final { |
86 | if (!success_) return; |
87 | if (iter_ < path_.size() && e.get() == path_[iter_++]) { |
88 | ExprFunctor::VisitExpr(e); |
89 | } else { |
90 | success_ = false; |
91 | return; |
92 | } |
93 | } |
94 | |
95 | void VisitExprDefault_(const Object* op) final { success_ = false; } |
96 | |
97 | SignType GetSignType(const PrimExpr& e) { |
98 | if (e.dtype().is_uint()) { |
99 | return kPositive; |
100 | } |
101 | return expr_map_[e].GetSignType(); |
102 | } |
103 | |
104 | void VisitExpr_(const VarNode* op) final {} |
105 | |
106 | void VisitExpr_(const AddNode* op) final { |
107 | bool left = op->a.get() == path_[iter_]; |
108 | result_ -= left ? op->b : op->a; |
109 | this->VisitExpr(left ? op->a : op->b); |
110 | } |
111 | |
112 | void VisitExpr_(const SubNode* op) final { |
113 | bool left = op->a.get() == path_[iter_]; |
114 | if (left) { |
115 | result_ += op->b; |
116 | } else { |
117 | result_ -= op->a; |
118 | result_ = -result_; |
119 | comp_op = ReverseOp(comp_op); |
120 | } |
121 | this->VisitExpr(left ? op->a : op->b); |
122 | } |
123 | |
124 | void VisitExpr_(const MulNode* op) final { |
125 | bool left = op->a.get() == path_[iter_]; |
126 | PrimExpr operand = left ? op->b : op->a; |
127 | PrimExpr target_var = left ? op->a : op->b; |
128 | |
129 | SignType sign_operand = GetSignType(operand); |
130 | if (sign_operand == SignType::kNegative) { |
131 | comp_op = ReverseOp(comp_op); |
132 | } else if (sign_operand == SignType::kUnknown) { |
133 | // unable to get the sign of operand |
134 | success_ = false; |
135 | return; |
136 | } |
137 | |
138 | // always use relax bound |
139 | bool divided = analyzer_.CanProve(floormod(result_, operand) == 0); |
140 | |
141 | result_ = floordiv(result_, operand); // rounding down here |
142 | |
143 | if (!divided) { |
144 | if (comp_op == kGreater) { |
145 | // System will round down in all the cases, so add one for result_ for kGreater |
146 | // (x >= 3/2 --> x >= 2) |
147 | // (x >= -3/2 --> x >= -1) |
148 | // (x >= 3/-2 --> x >= -1) |
149 | // (x >= -3/-2 --> x >= 2) |
150 | result_ += 1; |
151 | } else if (comp_op == kEqual) { |
152 | // condition unsatisfiable as with floor div, it will change the expression |
153 | success_ = false; |
154 | return; |
155 | } else { |
156 | // System rounds down in all cases, do nothing for kLess. |
157 | // ( x <= 3/2 --> x <= 1) |
158 | // ( x <= -3/2 --> x <= -2) |
159 | // ( x <= 3/-2 --> x <= -2) |
160 | // ( x <= -3/-2 --> x <= 1) |
161 | } |
162 | } |
163 | this->VisitExpr(left ? op->a : op->b); |
164 | } |
165 | |
166 | void VisitExpr_(const FloorDivNode* op) final { |
167 | if (op->b.get() == path_[iter_]) { |
168 | // Skip cases where the var is divisor. |
169 | success_ = false; |
170 | return; |
171 | } |
172 | PrimExpr divisor = op->b; |
173 | if (analyzer_.CanProveEqual(divisor, 0)) { |
174 | // Skip zero divisor |
175 | success_ = false; |
176 | return; |
177 | } |
178 | |
179 | SignType sign_operand = GetSignType(divisor); |
180 | if (sign_operand == SignType::kNegative) { |
181 | comp_op = ReverseOp(comp_op); |
182 | divisor = -divisor; |
183 | result_ = -result_; |
184 | } else if (sign_operand == SignType::kUnknown) { |
185 | // unable to get the sign of operand |
186 | success_ = false; |
187 | return; |
188 | } |
189 | |
190 | if (comp_op == kGreater) { |
191 | // (x // 6 >= 4 --> x >= 4 * 6) |
192 | result_ = result_ * divisor; |
193 | } else if (comp_op == kEqual) { |
194 | // The bound is not single directional |
195 | // (x // 6 == 4 --> 30 > x >= 24) |
196 | // TODO(@wrongtest): support bidirectional bound |
197 | success_ = false; |
198 | return; |
199 | } else { |
200 | // (x // 6 <= 4 --> x <= 4 * 6 + 5) |
201 | result_ = result_ * divisor + divisor - 1; |
202 | } |
203 | if (sign_operand == SignType::kNegative) { |
204 | // (x // -6 >= 4 --> -((x + 6 - 1) // 6) >= 4 |
205 | // --> (x + 6 - 1) // 6 <= -4 |
206 | result_ = result_ - divisor + 1; |
207 | } |
208 | |
209 | this->VisitExpr(op->a); |
210 | } |
211 | |
212 | PrimExpr result_; |
213 | CompareOp comp_op{kGreater}; |
214 | bool success_{true}; |
215 | |
216 | private: |
217 | void Init(); |
218 | void Transform(); |
219 | void Relax(); |
220 | CompareOp ReverseOp(CompareOp comp_op); |
221 | PrimExpr target_; |
222 | PrimExpr expr_; |
223 | const std::unordered_map<const VarNode*, IntSet>& hint_map_; |
224 | const std::unordered_map<const VarNode*, IntSet>& relax_map_; |
225 | ExprIntSetMap expr_map_; |
226 | std::vector<const Object*> path_; |
227 | size_t iter_{0}; |
228 | // internal analzyer |
229 | Analyzer analyzer_; |
230 | }; |
231 | |
232 | class BoundDeduceInputChecker : public ExprVisitor { |
233 | public: |
234 | bool Check(BoundDeducer* deducer) { |
235 | deducer_ = deducer; |
236 | this->VisitExpr(deducer_->expr_); |
237 | return target_count == 1; |
238 | } |
239 | |
240 | void VisitExpr(const PrimExpr& e) final { |
241 | if (e.same_as(deducer_->target_)) ++target_count; |
242 | ExprVisitor::VisitExpr(e); |
243 | } |
244 | |
245 | private: |
246 | BoundDeducer* deducer_; |
247 | size_t target_count{0}; |
248 | }; |
249 | |
250 | void BoundDeducer::Init() { |
251 | BoundDeduceInputChecker checker; |
252 | if (!checker.Check(this)) success_ = false; |
253 | Transform(); |
254 | } |
255 | |
256 | CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { |
257 | switch (comp_op) { |
258 | case kEqual: |
259 | return kEqual; // IntSet can not represent range for `NE |
260 | case kGreater: |
261 | return kLess; |
262 | case kLess: |
263 | return kGreater; |
264 | default: |
265 | LOG(FATAL) << "Not a valid compare op" ; |
266 | } |
267 | } |
268 | |
269 | void BoundDeducer::Transform() { |
270 | // We will ensure to set expr_ such that it contains target_ |
271 | if (const LTNode* op = expr_.as<LTNode>()) { |
272 | if (GetPath(target_, op->a).empty()) { |
273 | // a < b -> b >= a + 1 |
274 | comp_op = kGreater; |
275 | expr_ = op->b; |
276 | result_ = op->a + 1; |
277 | } else { |
278 | // a < b -> a <= b - 1 |
279 | comp_op = kLess; |
280 | expr_ = op->a; |
281 | result_ = op->b - 1; |
282 | } |
283 | } else if (const LENode* op = expr_.as<LENode>()) { |
284 | if (GetPath(target_, op->a).empty()) { |
285 | // a <= b -> b >= a |
286 | comp_op = kGreater; |
287 | expr_ = op->b; |
288 | result_ = op->a; |
289 | } else { |
290 | comp_op = kLess; |
291 | expr_ = op->a; |
292 | result_ = op->b; |
293 | } |
294 | } else if (const GTNode* op = expr_.as<GTNode>()) { |
295 | if (GetPath(target_, op->a).empty()) { |
296 | // a > b -> b <= a - 1 |
297 | comp_op = kLess; |
298 | expr_ = op->b; |
299 | result_ = op->a - 1; |
300 | } else { |
301 | // a > b -> a >= b + 1 |
302 | comp_op = kGreater; |
303 | expr_ = op->a; |
304 | result_ = op->b + 1; |
305 | } |
306 | } else if (const GENode* op = expr_.as<GENode>()) { |
307 | if (GetPath(target_, op->a).empty()) { |
308 | // a >= b -> b <= a |
309 | comp_op = kLess; |
310 | expr_ = op->b; |
311 | result_ = op->a; |
312 | } else { |
313 | comp_op = kGreater; |
314 | expr_ = op->a; |
315 | result_ = op->b; |
316 | } |
317 | } else if (const EQNode* op = expr_.as<EQNode>()) { |
318 | comp_op = kEqual; |
319 | if (GetPath(target_, op->a).empty()) { |
320 | // if the b == a -> a == b |
321 | expr_ = op->b; |
322 | result_ = op->a; |
323 | } else { |
324 | expr_ = op->a; |
325 | result_ = op->b; |
326 | } |
327 | } else { |
328 | success_ = false; |
329 | } |
330 | } |
331 | |
332 | void BoundDeducer::Deduce() { |
333 | Init(); |
334 | if (!success_) return; |
335 | |
336 | Relax(); |
337 | if (!success_) return; |
338 | // get the path |
339 | path_ = GetPath(target_, expr_); |
340 | if (!path_.size()) { |
341 | success_ = false; |
342 | return; |
343 | } |
344 | expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); |
345 | |
346 | this->VisitExpr(expr_); |
347 | } |
348 | |
349 | void BoundDeducer::Relax() { |
350 | IntSet a = EvalSet(expr_, relax_map_); |
351 | IntSet b = EvalSet(result_, relax_map_); |
352 | if (a.IsEverything() || b.IsEverything()) { |
353 | success_ = false; |
354 | return; |
355 | } |
356 | // Both LHS and RHS of the EQ should behave as constants e.g. i == j, |
357 | // can not be resolved when either `i` or `j` or both are variables with |
358 | // some Range OR `i` and `j` both should be a single point in IntSet |
359 | if (comp_op == kEqual && |
360 | (!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) { |
361 | success_ = false; |
362 | return; |
363 | } |
364 | expr_ = (comp_op == kGreater) ? a.min() : a.max(); |
365 | result_ = (comp_op == kGreater) ? b.max() : b.min(); |
366 | } |
367 | |
368 | IntSet DeduceBound(PrimExpr v, PrimExpr e, |
369 | const std::unordered_map<const VarNode*, IntSet>& hint_map, |
370 | const std::unordered_map<const VarNode*, IntSet>& relax_map) { |
371 | BoundDeducer d(v, e, hint_map, relax_map); |
372 | d.Deduce(); |
373 | if (!d.success_) return IntSet::Nothing(); |
374 | PrimExpr min = neg_inf(), max = pos_inf(); |
375 | if (d.comp_op == kEqual) { |
376 | min = d.result_; |
377 | max = d.result_; |
378 | } else if (d.comp_op == kGreater) { |
379 | min = d.result_; |
380 | } else { |
381 | max = d.result_; |
382 | } |
383 | return IntSet::Interval(min, max); |
384 | } |
385 | |
386 | // assuming e >= 0, deduce the bound of variable from it. |
387 | // return empty set to represent deduce failure. |
388 | IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map<Var, IntSet>& hint_map, |
389 | const Map<Var, IntSet>& relax_map) { |
390 | std::unordered_map<const VarNode*, IntSet> hmap; |
391 | for (auto kv : hint_map) { |
392 | hmap[kv.first.get()] = kv.second; |
393 | } |
394 | std::unordered_map<const VarNode*, IntSet> rmap; |
395 | for (auto kv : relax_map) { |
396 | rmap[kv.first.get()] = kv.second; |
397 | } |
398 | return DeduceBound(v, e, hmap, rmap); |
399 | } |
400 | |
401 | TVM_REGISTER_GLOBAL("arith.DeduceBound" ) |
402 | .set_body_typed([](PrimExpr v, PrimExpr cond, const Map<Var, IntSet> hint_map, |
403 | const Map<Var, IntSet> relax_map) { |
404 | return DeduceBound(v, cond, hint_map, relax_map); |
405 | }); |
406 | |
407 | } // namespace arith |
408 | } // namespace tvm |
409 | |