1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/const_int_bound.cc
22 */
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr_functor.h>

#include <algorithm>
#include <cmath>

#include "constraint_extract.h"
#include "int_operator.h"
#include "pattern_match.h"
33
34namespace tvm {
35namespace arith {
36
37using namespace tir;
38
39TVM_REGISTER_NODE_TYPE(ConstIntBoundNode);
40
41ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) {
42 auto node = make_object<ConstIntBoundNode>();
43 node->min_value = min_value;
44 node->max_value = max_value;
45 data_ = std::move(node);
46}
47
48ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
49 return ConstIntBound(min_value, max_value);
50}
51
52TVM_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound);
53
54inline void PrintBoundValue(std::ostream& os, int64_t val) {
55 if (val == ConstIntBound::kPosInf) {
56 os << "pos_inf";
57 } else if (val == ConstIntBound::kNegInf) {
58 os << "neg_inf";
59 } else {
60 os << val;
61 }
62}
63
64TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
65 .set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) {
66 auto* op = static_cast<const ConstIntBoundNode*>(node.get());
67 p->stream << "ConstIntBound[";
68 PrintBoundValue(p->stream, op->min_value);
69 p->stream << ',';
70 PrintBoundValue(p->stream, op->max_value);
71 p->stream << ']';
72 });
73
74// internal entry for const int bound
75struct ConstIntBoundAnalyzer::Entry {
76 int64_t min_value;
77 int64_t max_value;
78
79 bool is_const(int64_t value) const { return min_value == max_value && min_value == value; }
80
81 bool operator==(const Entry& other) const {
82 return min_value == other.min_value && max_value == other.max_value;
83 }
84};
85
86class ConstIntBoundAnalyzer::Impl
87 : public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> {
88 public:
89 /*! \brief additional bound info about expr in bound */
90 struct BoundInfo {
91 /*! \brief The expr */
92 PrimExpr expr;
93 /*! \brief The additional bound */
94 Entry bound;
95
96 BoundInfo() {}
97 BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {}
98 };
99
100 void Bind(const Var& var, const Range& range, bool allow_override) {
101 Entry a = VisitExpr(range->min);
102 Entry b = VisitExpr(range->extent);
103 Entry ret;
104 ret.min_value = a.min_value;
105 ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
106 Update(var, ret, allow_override);
107 }
108
109 void Update(const Var& var, const Entry& info, bool allow_override) {
110 if (!allow_override) {
111 auto it = var_map_.find(var);
112 if (it != var_map_.end()) {
113 ICHECK(it->second == info)
114 << "Trying to update var \'" << var << "\'"
115 << " with a different const bound: "
116 << "original=" << ConstIntBound(it->second.min_value, it->second.max_value)
117 << ", new=" << ConstIntBound(info.min_value, info.max_value);
118 }
119 }
120 var_map_[var] = info;
121 }
122
123 Entry VisitExpr_(const LetNode* op) final {
124 auto it = var_map_.find(op->var);
125 // if the var has not been binded, update the info.
126 if (it == var_map_.end()) {
127 var_map_[op->var] = this->VisitExpr(op->value);
128 Entry ret = VisitExpr(op->body);
129 var_map_.erase(op->var);
130 return ret;
131 } else {
132 return VisitExpr(op->body);
133 }
134 }
135
136 void Update(const Var& var, const ConstIntBound& info, bool allow_override) {
137 Update(var, MakeBound(info->min_value, info->max_value), allow_override);
138 }
139
140 // Override visitor behaviors
141 Entry VisitExprDefault_(const Object* op) final {
142 return Everything(static_cast<const PrimExprNode*>(op)->dtype);
143 }
144
145 Entry VisitExpr(const PrimExpr& expr) final {
146 Entry res = ExprFunctor::VisitExpr(expr);
147 tir::ExprDeepEqual equal;
148 // a linear search over additional info
149 // assume we won't have a lot of conditions
150 for (const BoundInfo& info : additional_info_) {
151 if (equal(expr, info.expr)) {
152 res = Intersect(res, info.bound);
153 }
154 }
155 if (bound_) {
156 auto val = bound_->find(expr);
157 if (val != bound_->end()) {
158 auto everything = Everything(expr->dtype);
159 ICHECK(
160 (val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
161 (val->second->min_value == everything.min_value &&
162 val->second->max_value == everything.max_value))
163 << "Detected bound for " << expr << "conflicts with memorization";
164 }
165 (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value);
166 }
167 return res;
168 }
169
170 Entry VisitExpr_(const RampNode* op) final {
171 // op = {base + i * stride | 0 <= i < lanes}
172 // Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
173 // Note that `base + i * stride` is linear w.r.t. `i`
174 // Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1)
175 Entry a = VisitExpr(op->base);
176 Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride);
177 return Union(a, b);
178 }
179
180 Entry VisitExpr_(const CastNode* op) final {
181 Entry a;
182
183 // int(ceil(log2(cast(n,"float64")))) is used as the
184 // implementation of topi.math.ceil_log2, and appears in iteration
185 // bounds.
186 if (auto opt = FindCeilLog2Arg(op)) {
187 a = CeilLog2Bounds(opt.value());
188 } else {
189 a = VisitExpr(op->value);
190 }
191
192 Entry b = Everything(op->dtype);
193 return Intersect(a, b);
194 }
195
196 Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); }
197
198 Entry VisitExpr_(const AddNode* op) final {
199 Entry a = VisitExpr(op->a);
200 Entry b = VisitExpr(op->b);
201 Entry ret;
202 ret.min_value = InfAwareAdd(a.min_value, b.min_value);
203 ret.max_value = InfAwareAdd(a.max_value, b.max_value);
204 return ret;
205 }
206
207 Entry VisitExpr_(const SubNode* op) final {
208 Entry a = VisitExpr(op->a);
209 Entry b = VisitExpr(op->b);
210 Entry ret;
211 ret.min_value = InfAwareAdd(a.min_value, -b.max_value);
212 ret.max_value = InfAwareAdd(a.max_value, -b.min_value);
213 return ret;
214 }
215
216 Entry VisitExpr_(const MulNode* op) final {
217 Entry a = VisitExpr(op->a);
218 Entry b = VisitExpr(op->b);
219 return BinaryOpBoundary(a, b, InfAwareMul);
220 }
221
222 Entry VisitExpr_(const DivNode* op) final {
223 Entry a = VisitExpr(op->a);
224 Entry b = VisitExpr(op->b);
225 ICHECK(!b.is_const(0)) << "divide by zero";
226 return HandleDivision(a, b, op->dtype, InfAwareDiv);
227 }
228
229 Entry VisitExpr_(const ModNode* op) final {
230 Entry a = VisitExpr(op->a);
231 Entry b = VisitExpr(op->b);
232 if (b.min_value > 0) {
233 int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
234 if (a.min_value >= 0) {
235 // 0 <= [a_min, a_max] < b_min
236 if (a.max_value < b.min_value) return a;
237 // other case, we can get close to 0
238 return MakeBound(0, std::min(a.max_value, b_max_cap));
239 } else {
240 return MakeBound(std::max(a.min_value, -b_max_cap),
241 std::min(std::max(a.max_value, (int64_t)0), b_max_cap));
242 }
243 } else {
244 ICHECK(!b.is_const(0)) << "mod by zero";
245 // mod by negative value is rare,
246 // and we just use the simpliest rule.
247 return Everything(op->dtype);
248 }
249 }
250
251 Entry VisitExpr_(const FloorDivNode* op) final {
252 Entry a = VisitExpr(op->a);
253 Entry b = VisitExpr(op->b);
254 ICHECK(!b.is_const(0)) << "floordiv by zero";
255 return HandleDivision(a, b, op->dtype, InfAwareFloorDiv);
256 }
257
258 Entry VisitExpr_(const FloorModNode* op) final {
259 /* let a / b = x + y, where x is integer, y \in [0, 1)
260 * floormod(a, b) = a - floordiv(a, b) * b
261 * floordiv(a, b) = x
262 * floormod(a, b) = a - floordiv(a, b) * b
263 * = a - x * b
264 * = a - (a / b - y) * b
265 * = a - a + y * b
266 * = y * b
267 * note that 0 <= y < 1
268 * when b > 0, 0 <= b * y < b
269 * 0 <= b * y <= b - 1
270 * when b < 0, b < b * y <= 0
271 * b + 1 <= b * y <= 0
272 * In all cases, min(0, b + 1) <= b * y <= max(0, b - 1)
273 * min(0, b_min + 1) <= b * y <= max(0, b_max - 1)
274 * That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1)
275 */
276 Entry a = VisitExpr(op->a);
277 Entry b = VisitExpr(op->b);
278 if (b.min_value > 0) {
279 int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
280 if (a.min_value >= 0) {
281 // 0 <= [a_min, a_max] < b_min
282 if (a.max_value < b.min_value) return a;
283 // other case, we can get close to 0
284 return MakeBound(0, std::min(a.max_value, b_max_cap));
285 } else {
286 return MakeBound(0, b_max_cap);
287 }
288 } else {
289 ICHECK(!b.is_const(0)) << "floormod by zero";
290 int64_t b_min_cap = InfAwareAdd(b.min_value, 1);
291 int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
292 return Intersect(MakeBound(std::min(static_cast<int64_t>(0), b_min_cap),
293 std::max(static_cast<int64_t>(0), b_max_cap)),
294 Everything(op->dtype));
295 }
296 }
297
298 Entry VisitExpr_(const MinNode* op) final {
299 Entry a = VisitExpr(op->a);
300 Entry b = VisitExpr(op->b);
301 Entry ret;
302 ret.min_value = std::min(a.min_value, b.min_value);
303 ret.max_value = std::min(a.max_value, b.max_value);
304 return ret;
305 }
306
307 Entry VisitExpr_(const MaxNode* op) final {
308 Entry a = VisitExpr(op->a);
309 Entry b = VisitExpr(op->b);
310 Entry ret;
311 ret.min_value = std::max(a.min_value, b.min_value);
312 ret.max_value = std::max(a.max_value, b.max_value);
313 return ret;
314 }
315
316 Entry VisitExpr_(const SelectNode* op) final {
317 Entry a = VisitExpr(op->true_value);
318 Entry b = VisitExpr(op->false_value);
319 return Union(a, b);
320 }
321
322 Entry VisitExpr_(const CallNode* op) final {
323 // only special handle >> and & which can be
324 // used for index calculation.
325
326 if (op->op.same_as(tir::builtin::shift_right())) {
327 return VisitRightShift(op);
328 } else if (op->op.same_as(tir::builtin::shift_left())) {
329 return VisitLeftShift(op);
330 } else if (op->op.same_as(tir::builtin::bitwise_and())) {
331 return VisitBitwiseAnd(op);
332 } else {
333 return Everything(op->dtype);
334 }
335 }
336
337 Entry VisitExpr_(const VarNode* op) final {
338 Var v = GetRef<Var>(op);
339 auto it = var_map_.find(v);
340 if (it != var_map_.end()) {
341 return it->second;
342 } else {
343 return Everything(op->dtype);
344 }
345 }
346
347 Entry VisitExpr_(const SizeVarNode* op) final {
348 SizeVar v = GetRef<SizeVar>(op);
349 auto it = var_map_.find(v);
350 if (it != var_map_.end()) {
351 return it->second;
352 } else {
353 return MakeBound(0, kPosInf);
354 }
355 }
356
357 Entry VisitLeftShift(const CallNode* op) {
358 Entry a = VisitExpr(op->args[0]);
359 Entry b = VisitExpr(op->args[1]);
360
361 if (a.min_value < 0 || b.min_value < 0) {
362 // If either operand can negative, we may run into undefined
363 // behavior for some targets. In these cases, avoid making any
364 // assumptions about the result.
365 return Everything(op->dtype);
366 }
367
368 return BinaryOpBoundary(a, b, InfAwareLeftShift);
369 }
370
371 Entry VisitRightShift(const CallNode* op) {
372 Entry a = VisitExpr(op->args[0]);
373 Entry b = VisitExpr(op->args[1]);
374 return BinaryOpBoundary(a, b, InfAwareRightShift);
375 }
376
377 Entry VisitBitwiseAnd(const CallNode* op) {
378 Entry a = VisitExpr(op->args[0]);
379 Entry b = VisitExpr(op->args[1]);
380 // handle positive index case.
381 if (a.min_value >= 0 && b.min_value >= 0) {
382 return MakeBound(0, std::min(a.max_value, b.max_value));
383 } else {
384 if (b.min_value >= 0) {
385 return MakeBound(0, b.max_value);
386 }
387 if (a.min_value >= 0) {
388 return MakeBound(0, a.max_value);
389 }
390 return Everything(op->dtype);
391 }
392 }
393
394 std::function<void()> EnterConstraint(const PrimExpr& constraint) {
395 std::vector<BoundInfo> info = DetectBoundInfo(constraint);
396 if (info.size() == 0) return nullptr;
397 size_t old_size = additional_info_.size();
398 additional_info_.insert(additional_info_.end(), info.begin(), info.end());
399 size_t new_size = old_size + info.size();
400 auto frecover = [old_size, new_size, this]() {
401 ICHECK_EQ(additional_info_.size(), new_size);
402 additional_info_.resize(old_size);
403 };
404 return frecover;
405 }
406
407 private:
408 friend class ConstIntBoundAnalyzer;
409 // internal variable map
410 std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> var_map_;
411 // additional bound info
412 std::vector<BoundInfo> additional_info_;
// lookup table for memoization
414 BoundMapType* bound_{nullptr};
// constants: the limit value means unlimited
416 // NOTE: kNegInf/kPosInf are used to represent infinity.
417 static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
418 static const constexpr int64_t kPosInf = ConstIntBound::kPosInf;
419 static_assert(-kNegInf == kPosInf, "invariant of inf");
420 // internal helper functions
421 /*!
422 * \brief Get boundary of binary op who are monotonic wrt to one argument.
423 * \param a The entry of the left operand.
424 * \param b The entry of the right operand.
425 * \param op The operator.
426 * \tparam F the operator function type.
427 * \return The result.
428 */
429 template <typename F>
430 static Entry BinaryOpBoundary(Entry a, Entry b, const F& op) {
431 Entry ret;
432 // The boundary point must be shihft of the original boundary.
433 int64_t v1 = op(a.min_value, b.min_value);
434 int64_t v2 = op(a.max_value, b.max_value);
435 int64_t v3 = op(a.min_value, b.max_value);
436 int64_t v4 = op(a.max_value, b.min_value);
437 ret.min_value = std::min(std::min(std::min(v1, v2), v3), v4);
438 ret.max_value = std::max(std::max(std::max(v1, v2), v3), v4);
439 return ret;
440 }
441 /*!
442 * \brief Get value boundaries of division (e.g. Div or FloorDiv).
443 * \param a The entry of the left operand.
444 * \param b The entry of the right operand.
445 * \param dt The data type of the division operator.
446 * \param op The division operator.
447 * \tparam F the operator function type.
448 * \return The result.
449 */
450 template <typename F>
451 static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) {
452 // Here we have a / b.
453 // The largest value of the division will be for the smallest (with
454 // respect to the absolute value) value of b. If the range of b starts
455 // at a negative value and ends at a positive one, narrow it down to
456 // be closer to 0, because BinaryOpBoundary only checks end-points of
457 // the domain ranges.
458
459 // If the range of b contains 0, then some infinity will be involved
460 if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) {
461 Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt);
462 Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt);
463
464 Entry e_neg = BinaryOpBoundary(a, b_neg, op);
465 Entry e_pos = BinaryOpBoundary(a, b_pos, op);
466
467 return MakeBound(std::min(e_neg.min_value, e_pos.min_value),
468 std::max(e_neg.max_value, e_pos.max_value));
469 } else if (b.min_value == 0 && dt.is_uint()) {
470 // uints only have one sided bounds
471 Entry assumed_b = MakeBound(1, b.max_value);
472 return BinaryOpBoundary(a, assumed_b, op);
473 }
474 // If the range of b does not have 0, use BinaryOpBoundary.
475 return BinaryOpBoundary(a, b, op);
476 }
477 /*!
478 * \brief Compute x + y, aware of inf.
479 * \param x The left operand.
480 * \param y The right operand.
481 * \return the result.
482 */
483 static int64_t InfAwareAdd(int64_t x, int64_t y) {
484 if (x == kPosInf) {
485 ICHECK(y != kNegInf);
486 return kPosInf;
487 }
488 if (x == kNegInf) {
489 ICHECK(y != kPosInf);
490 return kNegInf;
491 }
492 if (y == kPosInf || y == kNegInf) return y;
493 if (WillOverflow<AddNode>(x, y, kNegInf, kPosInf)) {
494 if (x > 0) return kPosInf;
495 return kNegInf;
496 }
497 return x + y;
498 }
499 /*!
500 * \brief Compute x * y, aware of inf.
501 * \param x The left operand.
502 * \param y The right operand.
503 * \return the result.
504 */
505 static int64_t InfAwareMul(int64_t x, int64_t y) {
506 if (!WillOverflow<MulNode>(x, y, kNegInf, kPosInf)) return x * y;
507 if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf;
508 return kNegInf;
509 }
510 /*!
511 * \brief Compute x / y, aware of inf.
512 * \param x The left operand.
513 * \param y The right operand.
514 * \return the result.
515 */
516 static int64_t InfAwareDiv(int64_t x, int64_t y) {
517 ICHECK_NE(y, 0);
518 if (x == kPosInf || x == kNegInf) {
519 if (y > 0) return x;
520 return -x;
521 }
522 return x / y;
523 }
524 /*!
525 * \brief Compute floodiv(x, y), aware of inf.
526 * \param x The left operand.
527 * \param y The right operand.
528 * \return the result.
529 */
530 static int64_t InfAwareFloorDiv(int64_t x, int64_t y) {
531 ICHECK_NE(y, 0);
532 if (x == kPosInf || x == kNegInf) {
533 if (y > 0) return x;
534 return -x;
535 }
536 return floordiv(x, y);
537 }
538 /*!
539 * \brief Compute x << y, aware of inf.
540 * \param x The left operand.
541 * \param y The right operand.
542 * \return the result.
543 */
544 static int64_t InfAwareLeftShift(int64_t x, int64_t y) {
545 if (x == kPosInf || x == kNegInf) return x;
546
547 // Can be replaced with std::bit_width in C++20
548 auto bit_width = [](int64_t as_signed) {
549 uint64_t val = std::abs(as_signed);
550 int num_bits = 0;
551 while (val) {
552 ++num_bits;
553 val >>= 1;
554 }
555 return num_bits;
556 };
557 int x_bits = bit_width(x);
558 if (x_bits + y < 64) {
559 return x << y;
560 } else {
561 return kPosInf;
562 }
563 }
564 /*!
565 * \brief Compute x >> y, aware of inf.
566 * \param x The left operand.
567 * \param y The right operand.
568 * \return the result.
569 */
570 static int64_t InfAwareRightShift(int64_t x, int64_t y) {
571 if (x == kPosInf || x == kNegInf) return x;
572 return x >> y;
573 }
574 /*!
575 * \brief Make a new bound entry.
576 */
577 static Entry MakeBound(int64_t min_value, int64_t max_value) {
578 Entry e;
579 e.min_value = (min_value == kPosInf) ? min_value - 1 : min_value;
580 e.max_value = (max_value == kNegInf) ? max_value + 1 : max_value;
581 return e;
582 }
583 /*!
584 * \brief Create union of two sets.
585 * \param a The left operand.
586 * \param b the right operand.
587 */
588 static Entry Union(Entry a, Entry b) {
589 Entry ret;
590 ret.min_value = std::min(a.min_value, b.min_value);
591 ret.max_value = std::max(a.max_value, b.max_value);
592 return ret;
593 }
594 /*!
595 * \brief Create intersect of two sets.
596 * \param a The left operand.
597 * \param b the right operand.
598 */
599 static Entry Intersect(Entry a, Entry b) {
600 Entry ret;
601 ret.min_value = std::max(a.min_value, b.min_value);
602 ret.max_value = std::min(a.max_value, b.max_value);
603 return ret;
604 }
605 /*!
606 * \brief return everything dtype can represent.
607 * \param dtype The data type.
608 * \return Bound that represent everything dtype can represent.
609 */
610 static Entry Everything(DataType dtype) {
611 if (!dtype.is_int() && !dtype.is_uint()) {
612 return MakeBound(kNegInf, kPosInf);
613 }
614 Entry ret;
615 int64_t vbits = dtype.bits() - static_cast<int>(dtype.is_int());
616 if (dtype.is_uint()) {
617 ret.min_value = 0;
618 } else {
619 if (vbits >= 63) {
620 ret.min_value = kNegInf;
621 } else {
622 ret.min_value = -(static_cast<int64_t>(1) << vbits);
623 }
624 }
625 if (vbits >= 63) {
626 ret.max_value = kPosInf;
627 } else {
628 ret.max_value = (static_cast<int64_t>(1) << vbits) - 1;
629 }
630 return ret;
631 }
632
633 /*!
634 * \brief Detect additional constant bound from cond, if any
635 * \param cond The constraint condition.
636 * \return List of detected bounds.
637 */
638 static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) {
639 PVar<PrimExpr> x, y;
640 PVar<IntImm> c;
641
642 std::vector<BoundInfo> info;
643 auto add_info = [&](const PrimExpr& expr, int64_t min_value, int64_t max_value) {
644 // If the conditional is comparing two integers, do not assign a
645 // value to them.
646 if (!expr->IsInstance<IntImmNode>()) {
647 info.push_back(BoundInfo(expr, MakeBound(min_value, max_value)));
648 }
649 };
650
651 for (const auto& subexpr : ExtractConstraints(cond)) {
652 // NOTE: The canonical form always uses <= or <, but a
653 // user-supplied constraint from the python API might not be
654 // canonicalized.
655 if ((c <= x).Match(subexpr) || (x >= c).Match(subexpr)) {
656 add_info(x.Eval(), c.Eval()->value, kPosInf);
657 } else if ((c < x).Match(subexpr) || (x > c).Match(subexpr)) {
658 add_info(x.Eval(), c.Eval()->value + 1, kPosInf);
659 } else if ((x <= c).Match(subexpr) || (x >= c).Match(subexpr)) {
660 add_info(x.Eval(), kNegInf, c.Eval()->value);
661 } else if ((x < c).Match(subexpr) || (c > x).Match(subexpr)) {
662 add_info(x.Eval(), kNegInf, c.Eval()->value - 1);
663 } else if ((x == c).Match(subexpr) || (c == x).Match(subexpr)) {
664 add_info(x.Eval(), c.Eval()->value, c.Eval()->value);
665 }
666 }
667
668 return info;
669 }
670
671 /*!
672 * \brief Extract the argument from int(ceil(log2(arg)))
673 *
674 * This expression is used as the implementation of
675 * topi.math.ceil_log2, and can appear in iteration bounds.
676 */
677 static Optional<PrimExpr> FindCeilLog2Arg(const CastNode* op) {
678 if (op->dtype.is_int()) {
679 if (auto as_call = op->value.as<CallNode>()) {
680 if (as_call->op.same_as(Op::Get("tir.ceil"))) {
681 PrimExpr ceil_arg = as_call->args[0];
682 if (auto arg_call = ceil_arg.as<CallNode>()) {
683 if (arg_call->op.same_as(Op::Get("tir.log2"))) {
684 PrimExpr log_arg = arg_call->args[0];
685 return log_arg;
686 }
687 }
688 }
689 }
690 }
691 return NullOpt;
692 }
693
694 /*! \brief Propagate constraints through ceil(log2(arg))
695 *
696 * Helper function for CastNode visitor
697 */
698 Entry CeilLog2Bounds(PrimExpr arg) {
699 if (auto as_float = arg.as<FloatImmNode>()) {
700 // A cast from int to float may have already been simplified
701 // out. Normally we don't inspect floating-point arguments, but here we can
702 int64_t val = std::ceil(std::log2(as_float->value));
703 return MakeBound(val, val);
704 } else {
705 Entry arg_bounds = VisitExpr(arg);
706 return MakeBound(std::ceil(std::log2(arg_bounds.min_value)),
707 std::ceil(std::log2(arg_bounds.max_value)));
708 }
709 }
710};
711
712ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {
713 Entry ret = impl_->VisitExpr(expr);
714 return ConstIntBound(ret.min_value, ret.max_value);
715}
716
717ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapType* bound) {
718 impl_->bound_ = bound;
719 Entry ret = impl_->VisitExpr(expr);
720 impl_->bound_ = nullptr;
721 return ConstIntBound(ret.min_value, ret.max_value);
722}
723
724void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool allow_override) {
725 impl_->Update(var, info, allow_override);
726}
727
728void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
729 impl_->Bind(var, range, allow_override);
730}
731
732std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
733 return impl_->EnterConstraint(constraint);
734}
735
736ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {}
737
738ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; }
739
740} // namespace arith
741} // namespace tvm
742