1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file bound_deducer.cc
22 * \brief Utility to deduce bound of expression
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/expr.h>
27#include <tvm/tir/expr_functor.h>
28
29#include <unordered_map>
30#include <unordered_set>
31
32#include "interval_set.h"
33
34namespace tvm {
35namespace arith {
36
37using namespace tir;
38
39// a visitor to find the path to the target variable
40// from a expression.
41class VariablePathFinder : public ExprVisitor {
42 public:
43 explicit VariablePathFinder(PrimExpr target) : target_(target) {}
44
45 void VisitExpr(const PrimExpr& node) final {
46 if (visited_.count(node.get()) != 0) return;
47 visited_.insert(node.get());
48
49 if (!found_) path_.push_back(node.get());
50 if (node.same_as(target_)) found_ = true;
51 ExprVisitor::VisitExpr(node);
52 if (!found_) path_.pop_back();
53 }
54
55 std::vector<const Object*> path_;
56
57 private:
58 bool found_{false};
59 PrimExpr target_;
60 std::unordered_set<const Object*> visited_;
61};
62
63// get the path to the variable,
64// return empty vector to represent failure
65std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) {
66 VariablePathFinder v(target);
67 v(expr);
68 return v.path_;
69}
70
// Direction of the normalized relation `expr <op> bound` tracked by the deducer.
enum CompareOp {
  kGreater,  // expr >= bound
  kLess,     // expr <= bound
  kEqual,    // expr == bound
};
72
73// a visitor to deduce the bound of a variable from a expression
74class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> {
75 public:
76 friend class BoundDeduceInputChecker;
77 friend class Converter;
78 BoundDeducer(PrimExpr target, PrimExpr expr,
79 const std::unordered_map<const VarNode*, IntSet>& hint_map,
80 const std::unordered_map<const VarNode*, IntSet>& relax_map)
81 : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {}
82
83 void Deduce();
84
85 void VisitExpr(const PrimExpr& e) final {
86 if (!success_) return;
87 if (iter_ < path_.size() && e.get() == path_[iter_++]) {
88 ExprFunctor::VisitExpr(e);
89 } else {
90 success_ = false;
91 return;
92 }
93 }
94
95 void VisitExprDefault_(const Object* op) final { success_ = false; }
96
97 SignType GetSignType(const PrimExpr& e) {
98 if (e.dtype().is_uint()) {
99 return kPositive;
100 }
101 return expr_map_[e].GetSignType();
102 }
103
104 void VisitExpr_(const VarNode* op) final {}
105
106 void VisitExpr_(const AddNode* op) final {
107 bool left = op->a.get() == path_[iter_];
108 result_ -= left ? op->b : op->a;
109 this->VisitExpr(left ? op->a : op->b);
110 }
111
112 void VisitExpr_(const SubNode* op) final {
113 bool left = op->a.get() == path_[iter_];
114 if (left) {
115 result_ += op->b;
116 } else {
117 result_ -= op->a;
118 result_ = -result_;
119 comp_op = ReverseOp(comp_op);
120 }
121 this->VisitExpr(left ? op->a : op->b);
122 }
123
124 void VisitExpr_(const MulNode* op) final {
125 bool left = op->a.get() == path_[iter_];
126 PrimExpr operand = left ? op->b : op->a;
127 PrimExpr target_var = left ? op->a : op->b;
128
129 SignType sign_operand = GetSignType(operand);
130 if (sign_operand == SignType::kNegative) {
131 comp_op = ReverseOp(comp_op);
132 } else if (sign_operand == SignType::kUnknown) {
133 // unable to get the sign of operand
134 success_ = false;
135 return;
136 }
137
138 // always use relax bound
139 bool divided = analyzer_.CanProve(floormod(result_, operand) == 0);
140
141 result_ = floordiv(result_, operand); // rounding down here
142
143 if (!divided) {
144 if (comp_op == kGreater) {
145 // System will round down in all the cases, so add one for result_ for kGreater
146 // (x >= 3/2 --> x >= 2)
147 // (x >= -3/2 --> x >= -1)
148 // (x >= 3/-2 --> x >= -1)
149 // (x >= -3/-2 --> x >= 2)
150 result_ += 1;
151 } else if (comp_op == kEqual) {
152 // condition unsatisfiable as with floor div, it will change the expression
153 success_ = false;
154 return;
155 } else {
156 // System rounds down in all cases, do nothing for kLess.
157 // ( x <= 3/2 --> x <= 1)
158 // ( x <= -3/2 --> x <= -2)
159 // ( x <= 3/-2 --> x <= -2)
160 // ( x <= -3/-2 --> x <= 1)
161 }
162 }
163 this->VisitExpr(left ? op->a : op->b);
164 }
165
166 void VisitExpr_(const FloorDivNode* op) final {
167 if (op->b.get() == path_[iter_]) {
168 // Skip cases where the var is divisor.
169 success_ = false;
170 return;
171 }
172 PrimExpr divisor = op->b;
173 if (analyzer_.CanProveEqual(divisor, 0)) {
174 // Skip zero divisor
175 success_ = false;
176 return;
177 }
178
179 SignType sign_operand = GetSignType(divisor);
180 if (sign_operand == SignType::kNegative) {
181 comp_op = ReverseOp(comp_op);
182 divisor = -divisor;
183 result_ = -result_;
184 } else if (sign_operand == SignType::kUnknown) {
185 // unable to get the sign of operand
186 success_ = false;
187 return;
188 }
189
190 if (comp_op == kGreater) {
191 // (x // 6 >= 4 --> x >= 4 * 6)
192 result_ = result_ * divisor;
193 } else if (comp_op == kEqual) {
194 // The bound is not single directional
195 // (x // 6 == 4 --> 30 > x >= 24)
196 // TODO(@wrongtest): support bidirectional bound
197 success_ = false;
198 return;
199 } else {
200 // (x // 6 <= 4 --> x <= 4 * 6 + 5)
201 result_ = result_ * divisor + divisor - 1;
202 }
203 if (sign_operand == SignType::kNegative) {
204 // (x // -6 >= 4 --> -((x + 6 - 1) // 6) >= 4
205 // --> (x + 6 - 1) // 6 <= -4
206 result_ = result_ - divisor + 1;
207 }
208
209 this->VisitExpr(op->a);
210 }
211
212 PrimExpr result_;
213 CompareOp comp_op{kGreater};
214 bool success_{true};
215
216 private:
217 void Init();
218 void Transform();
219 void Relax();
220 CompareOp ReverseOp(CompareOp comp_op);
221 PrimExpr target_;
222 PrimExpr expr_;
223 const std::unordered_map<const VarNode*, IntSet>& hint_map_;
224 const std::unordered_map<const VarNode*, IntSet>& relax_map_;
225 ExprIntSetMap expr_map_;
226 std::vector<const Object*> path_;
227 size_t iter_{0};
228 // internal analzyer
229 Analyzer analyzer_;
230};
231
232class BoundDeduceInputChecker : public ExprVisitor {
233 public:
234 bool Check(BoundDeducer* deducer) {
235 deducer_ = deducer;
236 this->VisitExpr(deducer_->expr_);
237 return target_count == 1;
238 }
239
240 void VisitExpr(const PrimExpr& e) final {
241 if (e.same_as(deducer_->target_)) ++target_count;
242 ExprVisitor::VisitExpr(e);
243 }
244
245 private:
246 BoundDeducer* deducer_;
247 size_t target_count{0};
248};
249
250void BoundDeducer::Init() {
251 BoundDeduceInputChecker checker;
252 if (!checker.Check(this)) success_ = false;
253 Transform();
254}
255
256CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) {
257 switch (comp_op) {
258 case kEqual:
259 return kEqual; // IntSet can not represent range for `NE
260 case kGreater:
261 return kLess;
262 case kLess:
263 return kGreater;
264 default:
265 LOG(FATAL) << "Not a valid compare op";
266 }
267}
268
269void BoundDeducer::Transform() {
270 // We will ensure to set expr_ such that it contains target_
271 if (const LTNode* op = expr_.as<LTNode>()) {
272 if (GetPath(target_, op->a).empty()) {
273 // a < b -> b >= a + 1
274 comp_op = kGreater;
275 expr_ = op->b;
276 result_ = op->a + 1;
277 } else {
278 // a < b -> a <= b - 1
279 comp_op = kLess;
280 expr_ = op->a;
281 result_ = op->b - 1;
282 }
283 } else if (const LENode* op = expr_.as<LENode>()) {
284 if (GetPath(target_, op->a).empty()) {
285 // a <= b -> b >= a
286 comp_op = kGreater;
287 expr_ = op->b;
288 result_ = op->a;
289 } else {
290 comp_op = kLess;
291 expr_ = op->a;
292 result_ = op->b;
293 }
294 } else if (const GTNode* op = expr_.as<GTNode>()) {
295 if (GetPath(target_, op->a).empty()) {
296 // a > b -> b <= a - 1
297 comp_op = kLess;
298 expr_ = op->b;
299 result_ = op->a - 1;
300 } else {
301 // a > b -> a >= b + 1
302 comp_op = kGreater;
303 expr_ = op->a;
304 result_ = op->b + 1;
305 }
306 } else if (const GENode* op = expr_.as<GENode>()) {
307 if (GetPath(target_, op->a).empty()) {
308 // a >= b -> b <= a
309 comp_op = kLess;
310 expr_ = op->b;
311 result_ = op->a;
312 } else {
313 comp_op = kGreater;
314 expr_ = op->a;
315 result_ = op->b;
316 }
317 } else if (const EQNode* op = expr_.as<EQNode>()) {
318 comp_op = kEqual;
319 if (GetPath(target_, op->a).empty()) {
320 // if the b == a -> a == b
321 expr_ = op->b;
322 result_ = op->a;
323 } else {
324 expr_ = op->a;
325 result_ = op->b;
326 }
327 } else {
328 success_ = false;
329 }
330}
331
332void BoundDeducer::Deduce() {
333 Init();
334 if (!success_) return;
335
336 Relax();
337 if (!success_) return;
338 // get the path
339 path_ = GetPath(target_, expr_);
340 if (!path_.size()) {
341 success_ = false;
342 return;
343 }
344 expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
345
346 this->VisitExpr(expr_);
347}
348
349void BoundDeducer::Relax() {
350 IntSet a = EvalSet(expr_, relax_map_);
351 IntSet b = EvalSet(result_, relax_map_);
352 if (a.IsEverything() || b.IsEverything()) {
353 success_ = false;
354 return;
355 }
356 // Both LHS and RHS of the EQ should behave as constants e.g. i == j,
357 // can not be resolved when either `i` or `j` or both are variables with
358 // some Range OR `i` and `j` both should be a single point in IntSet
359 if (comp_op == kEqual &&
360 (!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) {
361 success_ = false;
362 return;
363 }
364 expr_ = (comp_op == kGreater) ? a.min() : a.max();
365 result_ = (comp_op == kGreater) ? b.max() : b.min();
366}
367
368IntSet DeduceBound(PrimExpr v, PrimExpr e,
369 const std::unordered_map<const VarNode*, IntSet>& hint_map,
370 const std::unordered_map<const VarNode*, IntSet>& relax_map) {
371 BoundDeducer d(v, e, hint_map, relax_map);
372 d.Deduce();
373 if (!d.success_) return IntSet::Nothing();
374 PrimExpr min = neg_inf(), max = pos_inf();
375 if (d.comp_op == kEqual) {
376 min = d.result_;
377 max = d.result_;
378 } else if (d.comp_op == kGreater) {
379 min = d.result_;
380 } else {
381 max = d.result_;
382 }
383 return IntSet::Interval(min, max);
384}
385
386// assuming e >= 0, deduce the bound of variable from it.
387// return empty set to represent deduce failure.
388IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map<Var, IntSet>& hint_map,
389 const Map<Var, IntSet>& relax_map) {
390 std::unordered_map<const VarNode*, IntSet> hmap;
391 for (auto kv : hint_map) {
392 hmap[kv.first.get()] = kv.second;
393 }
394 std::unordered_map<const VarNode*, IntSet> rmap;
395 for (auto kv : relax_map) {
396 rmap[kv.first.get()] = kv.second;
397 }
398 return DeduceBound(v, e, hmap, rmap);
399}
400
401TVM_REGISTER_GLOBAL("arith.DeduceBound")
402 .set_body_typed([](PrimExpr v, PrimExpr cond, const Map<Var, IntSet> hint_map,
403 const Map<Var, IntSet> relax_map) {
404 return DeduceBound(v, cond, hint_map, relax_map);
405 });
406
407} // namespace arith
408} // namespace tvm
409