1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file int_set.cc
22 * \brief The integer set functions
23 */
24#include <tvm/arith/int_set.h>
25#include <tvm/arith/iter_affine_map.h>
26#include <tvm/runtime/registry.h>
27#include <tvm/tir/expr.h>
28#include <tvm/tir/expr_functor.h>
29
30#include <algorithm>
31#include <unordered_map>
32#include <utility>
33
34#include "constraint_extract.h"
35#include "interval_set.h"
36#include "pattern_match.h"
37
38namespace tvm {
39namespace arith {
40
41using tir::is_one;
42using tir::is_zero;
43using tir::make_const;
44using tir::make_zero;
45
46PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
47PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
48
49IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) {
50 auto node = make_object<IntervalSetNode>();
51 node->min_value = std::move(min_value);
52 node->max_value = std::move(max_value);
53 data_ = std::move(node);
54}
55
56IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) {
57 return IntervalSet(min_value, max_value);
58}
59
60TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet);
61
62IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
63 PrimExpr max_value = min(a->max_value, b->max_value);
64 PrimExpr min_value = max(a->min_value, b->min_value);
65 if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) &&
66 (min_value.dtype().is_int() || min_value.dtype().is_uint()) &&
67 analyzer->CanProve(max_value < min_value)) {
68 return IntervalSet::Empty();
69 } else {
70 return IntervalSet(min_value, max_value);
71 }
72}
73
74IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
75 if (a->IsEmpty()) return b;
76 if (b->IsEmpty()) return a;
77 PrimExpr max_value = max(a->max_value, b->max_value);
78 PrimExpr min_value = min(a->min_value, b->min_value);
79 return IntervalSet(min_value, max_value);
80}
81
82// type traits
83template <typename OP>
84struct is_logical_op {
85 static const bool value = false;
86};
87
88#define TVM_DECLARE_LOGICAL_OP(OP) \
89 template <> \
90 struct is_logical_op<tir::OP> { \
91 static const bool value = true; \
92 };
93
94TVM_DECLARE_LOGICAL_OP(And);
95TVM_DECLARE_LOGICAL_OP(Or);
96TVM_DECLARE_LOGICAL_OP(EQ);
97TVM_DECLARE_LOGICAL_OP(NE);
98TVM_DECLARE_LOGICAL_OP(GE);
99TVM_DECLARE_LOGICAL_OP(GT);
100TVM_DECLARE_LOGICAL_OP(LE);
101TVM_DECLARE_LOGICAL_OP(LT);
102TVM_DECLARE_LOGICAL_OP(Not);
103
104/*!
105 * \brief Combine two interval set under arithmetic operations.
106 * \note this can possibly relax the set.
107 */
108template <typename Op>
109inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) {
110 if (a->IsSinglePoint() && b->IsSinglePoint()) {
111 PrimExpr expr;
112 if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) {
113 expr = res.value();
114 } else {
115 expr = Op(a->min_value, b->min_value);
116 }
117 return IntervalSet::SinglePoint(expr);
118 }
119 if (is_logical_op<Op>::value) {
120 return IntervalSet(make_const(dtype, 0), make_const(dtype, 1));
121 }
122 if (a->IsEmpty()) return a;
123 if (b->IsEmpty()) return b;
124 if (a->IsEverything()) return a;
125 if (b->IsEverything()) return b;
126 return IntervalSet::Everything();
127}
128
129template <>
130inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalSet b,
131 DataType /* dtype */) {
132 if (a->IsSinglePoint() && b->IsSinglePoint()) {
133 return IntervalSet::SinglePoint(a->min_value + b->min_value);
134 }
135 if (a->IsEmpty()) return a;
136 if (b->IsEmpty()) return b;
137 PrimExpr min_value =
138 a->HasLowerBound() && b->HasLowerBound() ? a->min_value + b->min_value : neg_inf();
139 PrimExpr max_value =
140 a->HasUpperBound() && b->HasUpperBound() ? a->max_value + b->max_value : pos_inf();
141 return IntervalSet(min_value, max_value);
142}
143
144template <>
145inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalSet b,
146 DataType /* dtype */) {
147 if (a->IsSinglePoint() && b->IsSinglePoint()) {
148 return IntervalSet::SinglePoint(a->min_value - b->min_value);
149 }
150 if (a->IsEmpty()) return a;
151 if (b->IsEmpty()) return b;
152 PrimExpr min_value =
153 a->HasLowerBound() && b->HasUpperBound() ? a->min_value - b->max_value : neg_inf();
154 PrimExpr max_value =
155 a->HasUpperBound() && b->HasLowerBound() ? a->max_value - b->min_value : pos_inf();
156 return IntervalSet(min_value, max_value);
157}
158
159template <>
160inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
161 DataType /* dtype */) {
162 if (a->IsSinglePoint() && b->IsSinglePoint()) {
163 return IntervalSet::SinglePoint(a->min_value * b->min_value);
164 }
165 if (a->IsEmpty()) return a;
166 if (b->IsEmpty()) return b;
167 if (a->IsSinglePoint()) {
168 std::swap(a, b);
169 }
170 if (b->IsSinglePoint()) {
171 if (is_zero(b->min_value)) return b;
172 if (is_one(b->min_value)) return a;
173 if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
174 PrimExpr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf();
175 PrimExpr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf();
176 return IntervalSet(min_value, max_value);
177 } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
178 PrimExpr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf();
179 PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf();
180 return IntervalSet(min_value, max_value);
181 } else if (a->HasUpperBound() && a->HasLowerBound()) {
182 using tir::Select;
183 PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
184 PrimExpr e1 = a->min_value * b->min_value;
185 PrimExpr e2 = a->max_value * b->min_value;
186 return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1));
187 }
188 }
189 DLOG(WARNING) << "Return Everything in CombineInterval Mul";
190 return IntervalSet::Everything();
191}
192
193template <>
194inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
195 DataType /* dtype */) {
196 if (a->IsSinglePoint() && b->IsSinglePoint()) {
197 return IntervalSet::SinglePoint(a->min_value / b->min_value);
198 }
199 if (a->IsEmpty()) return a;
200 if (b->IsEmpty()) return b;
201 if (b->IsSinglePoint()) {
202 if (is_zero(b->min_value)) {
203 LOG(FATAL) << "Divide by zero in CombineInterval Div";
204 }
205 if (is_one(b->min_value)) return a;
206 // no relaxation is needed in here due to set is inclusive
207 if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
208 PrimExpr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
209 PrimExpr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
210 return IntervalSet(min_value, max_value);
211 } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
212 PrimExpr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
213 PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
214 return IntervalSet(min_value, max_value);
215 } else if (a->HasUpperBound() && a->HasLowerBound()) {
216 using tir::Select;
217 PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
218 PrimExpr e1 = a->min_value / b->min_value;
219 PrimExpr e2 = a->max_value / b->min_value;
220 return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1));
221 }
222 }
223 DLOG(WARNING) << "Return Everything in CombineInterval Div";
224 return IntervalSet::Everything();
225}
226
227template <>
228inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
229 DataType /* dtype */) {
230 if (a->IsSinglePoint() && b->IsSinglePoint()) {
231 return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
232 }
233 if (a->IsEmpty()) return a;
234 if (b->IsEmpty()) return b;
235
236 if (b->IsSinglePoint()) {
237 const PrimExpr& divisor = b->min_value;
238 if (is_zero(divisor)) {
239 LOG(FATAL) << "Modular by zero in CombineInterval Mod";
240 }
241 // We need to add more bound constraints throughout the code.
242 // The logic below assumes a is non-negative, which usually
243 // is the case of our application.
244 // TODO(tqchen): add bound constraints for a.
245 if (analyzer->CanProveGreaterEqual(divisor, 0)) {
246 return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
247 } else {
248 PrimExpr bound = abs(divisor) - 1;
249 return IntervalSet(-bound, bound);
250 }
251 }
252 DLOG(WARNING) << "Return Everything in CombineInterval Mod";
253 return IntervalSet::Everything();
254}
255
256template <>
257inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
258 DataType /* dtype */) {
259 if (a->IsSinglePoint() && b->IsSinglePoint()) {
260 return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
261 }
262 if (a->IsEmpty()) return a;
263 if (b->IsEmpty()) return b;
264 if (b->IsSinglePoint()) {
265 if (is_zero(b->min_value)) {
266 LOG(FATAL) << "Divide by zero in CombineInterval Div";
267 }
268 if (is_one(b->min_value)) return a;
269 // no relaxation is needed in here due to set is inclusive
270 if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
271 PrimExpr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf();
272 PrimExpr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf();
273 return IntervalSet(min_value, max_value);
274 } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
275 PrimExpr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf();
276 PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf();
277 return IntervalSet(min_value, max_value);
278 } else if (a->HasUpperBound() && a->HasLowerBound()) {
279 using tir::Select;
280 PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of());
281 PrimExpr e1 = floordiv(a->min_value, b->min_value);
282 PrimExpr e2 = floordiv(a->max_value, b->min_value);
283 return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1));
284 }
285 }
286 DLOG(WARNING) << "Return Everything in CombineInterval Div";
287 return IntervalSet::Everything();
288}
289
290template <>
291inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, IntervalSet b,
292 DataType /* dtype */) {
293 if (a->IsSinglePoint() && b->IsSinglePoint()) {
294 return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
295 }
296 if (a->IsEmpty()) return a;
297 if (b->IsEmpty()) return b;
298
299 if (b->IsSinglePoint()) {
300 const PrimExpr& divisor = b->min_value;
301 if (is_zero(divisor)) {
302 LOG(FATAL) << "Modular by zero in CombineInterval Mod";
303 }
304 if (analyzer->CanProveGreaterEqual(divisor, 0)) {
305 if (divisor.as<tir::IntImmNode>()) {
306 // a mod b = a - (a / b) * b if a_max / b == a_min / b
307 auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf();
308 auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf();
309 // We can compare +/- inf against each other, but cannot use
310 // operator== between the symbolic limits and an integer.
311 bool compatible_dtypes = !(qmin.dtype().is_handle() ^ qmax.dtype().is_handle());
312 if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) {
313 auto tmax = a->max_value - divisor * qmin;
314 auto tmin = a->min_value - divisor * qmin;
315 return IntervalSet(tmin, tmax);
316 }
317 }
318 return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
319 } else {
320 PrimExpr bound = abs(divisor) - 1;
321 return IntervalSet(-bound, bound);
322 }
323 }
324 DLOG(WARNING) << "Return Everything in CombineInterval Mod";
325 return IntervalSet::Everything();
326}
327
328template <>
329inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, IntervalSet b,
330 DataType /* dtype */) {
331 if (a->IsSinglePoint() && b->IsSinglePoint()) {
332 return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
333 }
334 if (a->IsEmpty()) return a;
335 if (b->IsEmpty()) return b;
336 return IntervalSet(max(a->min_value, b->min_value), max(a->max_value, b->max_value));
337}
338
339template <>
340inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, IntervalSet b,
341 DataType /* dtype */) {
342 if (a->IsSinglePoint() && b->IsSinglePoint()) {
343 return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
344 }
345 if (a->IsEmpty()) return a;
346 if (b->IsEmpty()) return b;
347 return IntervalSet(min(a->min_value, b->min_value), min(a->max_value, b->max_value));
348}
349
350// internal helper function to get an interval set
351IntervalSet ToIntervalSet(IntSet set) {
352 if (auto* node = set.as<IntervalSetNode>()) {
353 return GetRef<IntervalSet>(node);
354 }
355 DLOG(INFO) << "cannot resolve int set " << set;
356 return IntervalSet::Everything();
357}
358
359using namespace tir;
360
361// Simplified version of int set evaluator that operates on IntervalSet
362// We might use better set analysis in the future to replace the intervalset.
363class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
364 public:
365 IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map,
366 const std::vector<std::pair<Var, IntSet>>* dom_constraints = nullptr,
367 bool eval_vec = false)
368 : analyzer_(analyzer),
369 dom_map_(dom_map),
370 dom_constraints_(dom_constraints),
371 eval_vec_(eval_vec) {}
372
373 IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); }
374 // evaluate and relax the set
375 IntervalSet Eval(IntervalSet val) {
376 // avoid recursive indefinite recursive expansion.
377 if (static_cast<size_t>(recur_depth_) >= dom_map_.size()) return val;
378 ++recur_depth_;
379 IntervalSet min_set = this->Eval(val->min_value);
380 IntervalSet max_set = this->Eval(val->max_value);
381 --recur_depth_;
382 return IntervalSet(min_set->min_value, max_set->max_value);
383 }
384
385 IntervalSet VisitExpr_(const IntImmNode* op) final {
386 return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
387 }
388
389 IntervalSet VisitExpr_(const VarNode* op) final {
390 Var var = GetRef<Var>(op);
391
392 Array<IntSet> values;
393 if (dom_constraints_) {
394 for (const auto& constraint : *dom_constraints_) {
395 if (var.same_as(constraint.first)) {
396 values.push_back(constraint.second);
397 }
398 }
399 }
400
401 auto it = dom_map_.find(var);
402 if (it != dom_map_.end()) {
403 values.push_back((*it).second);
404 }
405
406 if (values.empty()) {
407 return IntervalSet::SinglePoint(var);
408 }
409
410 IntSet intersection = [&]() {
411 if (values.size() == 1) {
412 return values.front();
413 } else {
414 return Intersect(values);
415 }
416 }();
417
418 IntervalSet res = ToIntervalSet(intersection);
419 if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
420 return res;
421 }
422 // recursively evaluate mapped result
423 // in case the domain contains variables to be relaxed.
424 return Eval(res);
425 }
426
427 IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }
428
429 IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_<Sub>(op); }
430
431 IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_<Mul>(op); }
432
433 IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_<Div>(op); }
434
435 IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_<Mod>(op); }
436
437 IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_<FloorDiv>(op); }
438
439 IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_<FloorMod>(op); }
440
441 IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_<Min>(op); }
442
443 IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_<Max>(op); }
444
445 IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_<EQ>(op); }
446
447 IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_<NE>(op); }
448
449 IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_<LT>(op); }
450
451 IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_<LE>(op); }
452
453 IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_<GT>(op); }
454
455 IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_<GE>(op); }
456
457 IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_<And>(op); }
458
459 IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_<Or>(op); }
460
461 IntervalSet VisitExpr_(const RampNode* op) final {
462 ICHECK(eval_vec_);
463 IntervalSet base = Eval(op->base);
464 PVar<IntImm> stride;
465 if (stride.Match(op->stride)) {
466 DataType t = op->base.dtype();
467 int64_t vstride = stride.Eval()->value;
468 if (vstride > 0) {
469 return Combine<Add>(analyzer_, base,
470 IntervalSet(make_zero(t), make_const(t, vstride * (op->lanes - 1))),
471 op->dtype);
472 } else {
473 return Combine<Add>(analyzer_, base,
474 IntervalSet(make_const(t, vstride * (op->lanes - 1)), make_zero(t)),
475 op->dtype);
476 }
477 }
478 DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<PrimExpr>(op);
479 return IntervalSet::Everything();
480 }
481
482 IntervalSet VisitExpr_(const BroadcastNode* op) final {
483 ICHECK(eval_vec_);
484 return VisitExpr(op->value);
485 }
486
487 IntervalSet VisitExpr_(const SelectNode* op) final {
488 IntervalSet true_set = this->Eval(op->true_value);
489 IntervalSet false_set = this->Eval(op->false_value);
490 return Union(analyzer_, false_set, true_set);
491 }
492
493 IntervalSet VisitExpr_(const CastNode* op) final {
494 IntervalSet value_set = this->Eval(op->value);
495 PrimExpr min_value =
496 value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf();
497 PrimExpr max_value =
498 value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf();
499 return IntervalSet(min_value, max_value);
500 }
501
502 IntervalSet VisitExpr_(const BufferLoadNode* op) final {
503 if (!(op->dtype.is_int() || op->dtype.is_uint())) {
504 DLOG(WARNING) << "cannot evaluate set BufferLoad which loads from a " << op->dtype
505 << " buffer";
506 return IntervalSet::Everything();
507 }
508 // If the indices do not contain any variables to be relaxed, return the BufferLoad itself.
509 // Otherwise return `IntervalSet::everything()` since we have no knowledge on the buffer data.
510 for (const PrimExpr& index : op->indices) {
511 if (UsesVar(index, [dom_map = &this->dom_map_](const VarNode* var) {
512 return dom_map->find(GetRef<Var>(var)) != dom_map->end();
513 })) {
514 return IntervalSet::Everything();
515 }
516 }
517 return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
518 }
519
520 IntervalSet VisitExprDefault_(const Object* op) final {
521 DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
522 return IntervalSet::Everything();
523 }
524
525 private:
526 // whether set is exactly single point that equals value.
527 bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const {
528 return set->min_value.same_as(value) && set->max_value.same_as(value);
529 }
530
531 template <typename TOp, typename T>
532 inline IntervalSet VisitBinaryExpr_(const T* op) {
533 static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint");
534 IntervalSet a = this->Eval(op->a);
535 IntervalSet b = this->Eval(op->b);
536 if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
537 return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
538 }
539 return Combine<TOp>(analyzer_, a, b, op->dtype);
540 }
541
542 // recursive depth
543 int recur_depth_{0};
544 // analyzer
545 Analyzer* analyzer_;
546 const Map<Var, IntSet>& dom_map_;
547 const std::vector<std::pair<Var, IntSet>>* dom_constraints_;
548 bool eval_vec_{false};
549};
550
551class IntSetAnalyzer::Impl {
552 public:
553 explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {}
554
555 IntSet Eval(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const {
556 return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
557 }
558
559 IntSet Eval(const PrimExpr& expr) const {
560 return IntervalSetEvaluator(analyzer_, dom_map_, &dom_constraints_, true).Eval(expr);
561 }
562
563 void Bind(const Var& var, const Range& range, bool allow_override) {
564 Update(var, IntSet::FromRange(range), allow_override);
565 }
566
567 void Update(const Var& var, const IntSet& info, bool override_info);
568 void Bind(const Var& var, const PrimExpr& expr, bool override_info);
569 std::function<void()> EnterConstraint(const PrimExpr& constraint);
570
571 private:
572 // Utility function to split a boolean condition into the domain
573 // bounds implied by that condition.
574 static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr& cond);
575
576 // The parent arith::Analyzer
577 Analyzer* analyzer_;
578
579 // Map of variables to global variable bounds (e.g. loop iterator
580 // ranges)
581 Map<Var, IntSet> dom_map_;
582
583 // List of implicit scope-dependent bounds (e.g. inside the body of
584 // an if-statement). Maintained as a list of constraints, rather
585 // than as a `Map<Var,IntSet>`, to avoid computing an Intersection
586 // until required.
587 std::vector<std::pair<Var, IntSet>> dom_constraints_;
588};
589
590IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
591
592IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; }
593
594IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) {
595 return impl_->Eval(expr, dom_map);
596}
597
598IntSet IntSetAnalyzer::operator()(const PrimExpr& expr) { return impl_->Eval(expr); }
599
600void IntSetAnalyzer::Update(const Var& var, const IntSet& info, bool allow_override) {
601 impl_->Update(var, info, allow_override);
602}
603
604void IntSetAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
605 impl_->Bind(var, range, allow_override);
606}
607
608void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_override) {
609 if (!can_override) {
610 auto it = dom_map_.find(var);
611 if (it != dom_map_.end()) {
612 const IntSet& old_info = (*it).second;
613
614 ICHECK(ExprDeepEqual()(old_info.min(), info.min()))
615 << "Trying to update var \'" << var << "\'"
616 << " with a different minimum value: "
617 << "original=" << old_info.min() << ", new=" << info.min();
618
619 ICHECK(ExprDeepEqual()(old_info.max(), info.max()))
620 << "Trying to update var \'" << var << "\'"
621 << " with a different maximum value: "
622 << "original=" << old_info.max() << ", new=" << info.max();
623 }
624 }
625 dom_map_.Set(var, info);
626}
627
628void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_override) {
629 Update(var, Eval(expr), can_override);
630}
631
632std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo(
633 const PrimExpr& constraint) {
634 PVar<Var> x;
635 PVar<PrimExpr> limit;
636
637 std::vector<std::pair<Var, IntSet>> bounds;
638 for (const PrimExpr& subconstraint : ExtractConstraints(constraint)) {
639 if ((x <= limit).Match(subconstraint)) {
640 bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())});
641 } else if ((x < limit).Match(subconstraint)) {
642 bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)});
643 } else if ((x >= limit).Match(subconstraint)) {
644 bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)});
645 } else if ((x > limit).Match(subconstraint)) {
646 bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)});
647 } else if ((x == limit).Match(subconstraint)) {
648 bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())});
649 }
650
651 if ((limit >= x).Match(subconstraint)) {
652 bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())});
653 } else if ((limit > x).Match(subconstraint)) {
654 bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)});
655 } else if ((limit <= x).Match(subconstraint)) {
656 bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)});
657 } else if ((limit < x).Match(subconstraint)) {
658 bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)});
659 } else if ((limit == x).Match(subconstraint)) {
660 bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())});
661 }
662 }
663 return bounds;
664}
665
666std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
667 return impl_->EnterConstraint(constraint);
668}
669
670std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) {
671 auto bounds = DetectBoundInfo(constraint);
672
673 if (bounds.size() == 0) return nullptr;
674
675 size_t old_size = dom_constraints_.size();
676 dom_constraints_.insert(dom_constraints_.end(), bounds.begin(), bounds.end());
677 size_t new_size = dom_constraints_.size();
678 auto frecover = [old_size, new_size, this]() {
679 ICHECK_EQ(dom_constraints_.size(), new_size);
680 dom_constraints_.resize(old_size);
681 };
682 return frecover;
683}
684
685// Quickly adapt to IntSet interface
686// TODO(tqchen): revisit IntSet interface as well.
687Range IntSet::CoverRange(Range max_range) const {
688 IntSet temp;
689 Analyzer analyzer;
690 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
691 ICHECK(s_int != nullptr);
692 if (s_int->HasUpperBound() && s_int->HasLowerBound()) {
693 return Range::FromMinExtent(analyzer.Simplify(s_int->min_value),
694 analyzer.Simplify(s_int->max_value + 1 - s_int->min_value));
695 }
696 return max_range;
697}
698
699PrimExpr IntSet::min() const {
700 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
701 ICHECK(s_int);
702 return s_int->min_value;
703}
704
705PrimExpr IntSet::max() const {
706 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
707 ICHECK(s_int);
708 return s_int->max_value;
709}
710
711bool IntSet::IsNothing() const {
712 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
713 return (s_int && s_int->IsEmpty());
714}
715
716bool IntSet::IsEverything() const {
717 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
718 return (s_int && s_int->IsEverything());
719}
720
721bool IntSet::IsSinglePoint() const {
722 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
723 return (s_int && s_int->IsSinglePoint());
724}
725
726bool IntSet::CanProvePositive() const {
727 Analyzer analyzer;
728 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
729 return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value)));
730}
731
732bool IntSet::CanProveNegative() const {
733 Analyzer analyzer;
734 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
735 return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value)));
736}
737
738bool IntSet::CanProveNonPositive() const {
739 Analyzer analyzer;
740 if (const auto* s_int = (*this).as<IntervalSetNode>()) {
741 auto max = analyzer.Simplify(s_int->max_value);
742 return is_zero(max) || is_negative_const(max);
743 }
744 return false;
745}
746
747bool IntSet::CanProveNonNegative() const {
748 Analyzer analyzer;
749 if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
750 auto min = analyzer.Simplify(s_int->min_value);
751 return is_zero(min) || is_positive_const(min);
752 }
753 return false;
754}
755
756bool IntSet::HasLowerBound() const {
757 if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
758 return s_int->HasLowerBound();
759 }
760 return false;
761}
762
763bool IntSet::HasUpperBound() const {
764 if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
765 return s_int->HasUpperBound();
766 }
767 return false;
768}
769
770SignType IntSet::GetSignType() const {
771 if (CanProvePositive()) {
772 return kPositive;
773 } else if (CanProveNegative()) {
774 return kNegative;
775 } else if (IsSinglePoint() && is_zero(PointValue())) {
776 return kZero;
777 } else {
778 return kUnknown;
779 }
780}
781PrimExpr IntSet::PointValue() const {
782 const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
783 ICHECK(s_int && s_int->IsSinglePoint());
784 return s_int->min_value;
785}
786
787IntSet IntSet::Nothing() { return IntervalSet::Empty(); }
788
789IntSet IntSet::Everything() { return IntervalSet::Everything(); }
790
791IntSet IntSet::SinglePoint(PrimExpr x) { return IntervalSet::SinglePoint(x); }
792
793IntSet IntSet::Interval(PrimExpr min, PrimExpr max) {
794 if (min.same_as(max)) {
795 return IntSet::SinglePoint(min);
796 }
797 return IntervalSet(min, max);
798}
799
800// Range related code
801inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) {
802 return is_zero(analyzer->Simplify(lhs - rhs));
803}
804
805IntSet IntSet::FromMinExtent(PrimExpr min, PrimExpr extent) {
806 if (is_one(extent)) {
807 return IntSet::SinglePoint(min);
808 }
809 return IntervalSet(min, extent + min - 1);
810}
811
812IntSet IntSet::FromRange(Range r) {
813 // must make sure it can be matched back by MatchRange.
814 if (is_one(r->extent)) {
815 return IntSet::SinglePoint(r->min);
816 }
817 return IntervalSet(r->min, r->extent + r->min - 1);
818}
819
820bool IntSet::MatchRange(const Range& b) const {
821 const IntSet& a = *this;
822 const IntervalSetNode* a_int = a.as<IntervalSetNode>();
823 if (!a_int) return false;
824 if (!a_int->HasUpperBound() || !a_int->HasLowerBound()) return false;
825 Analyzer ana;
826 return ProveEqual(&ana, a_int->min_value, b->min) &&
827 ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1);
828}
829
830IntSet Union(const Array<IntSet>& sets) {
831 if (sets.size() == 0) return IntSet::Nothing();
832 if (sets.size() == 1) return sets[0];
833 Analyzer ana;
834 IntervalSet x = ToIntervalSet(sets[0]);
835 for (size_t i = 1; i < sets.size(); ++i) {
836 x = Union(&ana, x, ToIntervalSet(sets[i]));
837 }
838 return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value));
839}
840
841Array<IntSet> UnionRegion(const Array<Array<IntSet>>& nd_int_sets) {
842 if (nd_int_sets.empty()) {
843 return {};
844 }
845 int n = nd_int_sets.size();
846 int ndim = nd_int_sets[0].size();
847 Array<IntSet> result;
848 result.reserve(ndim);
849 for (int i = 0; i < ndim; ++i) {
850 Array<IntSet> candidates;
851 candidates.reserve(n);
852 for (int j = 0; j < n; ++j) {
853 candidates.push_back(nd_int_sets[j][i]);
854 }
855 result.push_back(Union(candidates));
856 }
857 return result;
858}
859
860IntSet UnionLowerBound(const Array<IntSet>& sets) {
861 if (sets.size() == 0) return IntSet::Nothing();
862 if (sets.size() == 1) return sets[0];
863 Analyzer analyzer;
864 bool is_first_interval = true;
865 PrimExpr min_inclusive{nullptr};
866 PrimExpr max_inclusive(nullptr);
867 for (const IntSet& int_set : sets) {
868 if (const auto* interval_set = int_set.as<IntervalSetNode>()) {
869 PrimExpr new_min_inclusive = interval_set->min_value;
870 PrimExpr new_max_inclusive = interval_set->max_value;
871 if (is_first_interval) {
872 is_first_interval = false;
873 min_inclusive = std::move(new_min_inclusive);
874 max_inclusive = std::move(new_max_inclusive);
875 continue;
876 }
877 bool bound_1 = is_neg_inf(new_min_inclusive) || is_pos_inf(max_inclusive) ||
878 analyzer.CanProve(new_min_inclusive <= max_inclusive + 1);
879 bool bound_2 = is_neg_inf(min_inclusive) || is_pos_inf(new_max_inclusive) ||
880 analyzer.CanProve(min_inclusive <= new_max_inclusive + 1);
881 if (bound_1 && bound_2) {
882 min_inclusive = min(min_inclusive, new_min_inclusive);
883 max_inclusive = max(max_inclusive, new_max_inclusive);
884 }
885 }
886 }
887 if (is_first_interval) {
888 return IntSet::Nothing();
889 }
890 return IntSet::Interval(min_inclusive, max_inclusive);
891}
892
893Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets) {
894 if (nd_int_sets.empty()) {
895 return {};
896 }
897 int n = nd_int_sets.size();
898 int ndim = nd_int_sets[0].size();
899 Array<IntSet> result;
900 result.reserve(ndim);
901 for (int i = 0; i < ndim; ++i) {
902 Array<IntSet> candidates;
903 candidates.reserve(n);
904 for (int j = 0; j < n; ++j) {
905 candidates.push_back(nd_int_sets[j][i]);
906 }
907 result.push_back(UnionLowerBound(candidates));
908 }
909 return result;
910}
911
912IntSet Intersect(const Array<IntSet>& sets) {
913 if (sets.size() == 0) return IntSet::Nothing();
914 if (sets.size() == 1) return sets[0];
915 Analyzer ana;
916 IntervalSet x = ToIntervalSet(sets[0]);
917 for (size_t i = 1; i < sets.size(); ++i) {
918 x = Intersect(&ana, x, ToIntervalSet(sets[i]));
919 }
920 return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value));
921}
922
923Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) {
924 Map<Var, IntSet> dmap;
925 for (auto kv : dom_map) {
926 dmap.Set(kv.first->var, kv.second);
927 }
928 return dmap;
929}
930
931Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map) {
932 Map<Var, IntSet> dmap;
933 for (auto kv : dom_map) {
934 dmap.Set(GetRef<Var>(kv.first), kv.second);
935 }
936 return dmap;
937}
938
939IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
940 Analyzer ana;
941 return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e);
942}
943
944IntSet IntSet::Vector(PrimExpr x) {
945 Analyzer ana;
946 Map<Var, IntSet> dmap;
947 return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
948}
949
950IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {
951 return EvalSet(e, ConvertDomMap(dom_map));
952}
953
954IntSet EvalSet(PrimExpr e, const std::unordered_map<const VarNode*, IntSet>& dom_map) {
955 return EvalSet(e, ConvertDomMap(dom_map));
956}
957
958IntSet EvalSet(Range r, const Map<Var, IntSet>& dom_map) {
959 Analyzer ana;
960 if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) {
961 return EvalSet(r->min, dom_map);
962 }
963 IntervalSetEvaluator m(&ana, dom_map);
964 // Simplifying first can give tighter bounds if r->min and r->extent share variables
965 PrimExpr sum = r->min + r->extent - 1;
966 auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum)));
967 return std::move(res);
968}
969
970IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map) {
971 return EvalSet(r, ConvertDomMap(dom_map));
972}
973
974Array<IntSet> EvalSet(const Array<Range>& region, const Map<Var, IntSet>& dom_map) {
975 Analyzer ana;
976 IntervalSetEvaluator m(&ana, dom_map);
977 Array<IntSet> result;
978 result.reserve(region.size());
979 for (const Range& r : region) {
980 PrimExpr sum = r->min + (r->extent - 1);
981 result.push_back(m.Eval(IntervalSet(r->min, ana.Simplify(sum))));
982 }
983 return result;
984}
985
986IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map) {
987 Analyzer ana;
988 auto dmap = ConvertDomMap(dom_map);
989 IntervalSetEvaluator m(&ana, dmap);
990 const IntervalSetNode* s_int = s.as<IntervalSetNode>();
991 PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value;
992 PrimExpr vmin = s_int->HasLowerBound() ? m.Eval(s_int->min_value).min() : s_int->min_value;
993 return IntervalSet(vmin, vmax);
994}
995
996class SubExprIntervalSetEvaluator : public IntervalSetEvaluator {
997 public:
998 explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map)
999 : IntervalSetEvaluator(analyzer, dom_map) {}
1000
1001 IntervalSet VisitExpr(const PrimExpr& n) final {
1002 IntervalSet ret = IntervalSetEvaluator::VisitExpr(n);
1003 expr_map[n] = ret;
1004 return ret;
1005 }
1006
1007 ExprIntSetMap expr_map;
1008};
1009
1010ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
1011 const std::unordered_map<const VarNode*, IntSet>& dom_map) {
1012 Analyzer ana;
1013 auto dmap = ConvertDomMap(dom_map);
1014 SubExprIntervalSetEvaluator m(&ana, dmap);
1015 m.Eval(e);
1016 return m.expr_map;
1017}
1018
1019IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map) {
1020 return EvalSet(r, ConvertDomMap(dom_map));
1021}
1022
1023Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
1024 Map<Var, arith::IntSet> result;
1025 for (auto kv : var_dom) {
1026 const Var& var = kv.first;
1027 const Range& range = kv.second;
1028 result.Set(var, arith::IntSet::FromRange(range));
1029 }
1030 return result;
1031}
1032
1033/*! \brief Helper function to convert IterSumExpr to the actual touched range. */
1034static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent,
1035 Analyzer* analyzer) {
1036 if (iter_min->args.empty()) {
1037 return IntSet::FromMinExtent(iter_min->base, extent);
1038 }
1039 ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr";
1040 const IterSplitExpr& split = iter_min->args[0];
1041 if (!analyzer->CanProve(extent >= split->scale)) {
1042 return NullOpt;
1043 }
1044
1045 const PrimExpr& base = iter_min->base;
1046 // IterSplitExpr: (source // lower_factor) % extent * scale
1047 // where `(source // lower_factor) % extent` is within [0, extent - 1]
1048 if (analyzer->CanProve(split->scale < 0)) {
1049 // If scale is negative, the var dom is [(extent - 1) * scale, 0]
1050 // The total base is `base + (extent - 1) * scale`,
1051 // while total extent is `dom_extent + (extent - 1) * (-scale)`
1052 const PrimExpr& var_extent = (split->extent - 1) * split->scale;
1053 return IntSet::FromMinExtent(base + var_extent, extent - var_extent);
1054 } else {
1055 // If scale is positive, the var dom is [0, (extent - 1) * scale]
1056 // The total dom is [base, dom_extent + (extent - 1) * scale]
1057 return IntSet::FromMinExtent(base, extent + (split->extent - 1) * split->scale);
1058 }
1059}
1060
1061Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region,
1062 const Map<Var, Range>& var_dom,
1063 const PrimExpr& predicate, Analyzer* analyzer) {
1064 int ndim = region.size();
1065 Array<IterSumExpr> iter_sum_exprs{nullptr};
1066 {
1067 Array<PrimExpr> affine_indices;
1068 affine_indices.reserve(ndim);
1069 for (const Range& range : region) {
1070 if (!is_const_number(range->extent)) {
1071 // dynamic extent is not supported yet.
1072 return NullOpt;
1073 }
1074 affine_indices.push_back(range->min);
1075 }
1076 auto res = DetectIterMap(
1077 /*indices=*/affine_indices, /*input_iters=*/var_dom,
1078 /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer);
1079 iter_sum_exprs = res->indices;
1080 }
1081 if (iter_sum_exprs.empty()) {
1082 return NullOpt;
1083 }
1084 ICHECK_EQ(iter_sum_exprs.size(), ndim);
1085 Array<IntSet> result;
1086 result.reserve(ndim);
1087 for (int i = 0; i < ndim; ++i) {
1088 const IterSumExpr& sum_expr = iter_sum_exprs[i];
1089 const Range& range = region[i];
1090 Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer);
1091 if (int_set.defined()) {
1092 result.push_back(int_set.value());
1093 } else {
1094 return NullOpt;
1095 }
1096 }
1097 return result;
1098}
1099
1100Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
1101 const Map<Var, Range>& var_dom,
1102 const PrimExpr& predicate,
1103 arith::Analyzer* analyzer) {
1104 return EstimateRegionStrictBound(region, var_dom, predicate, analyzer);
1105}
1106
1107Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region, const Map<Var, Range>& var_dom,
1108 const PrimExpr& predicate, Analyzer* analyzer) {
1109 if (Optional<Array<arith::IntSet>> result = EstimateRegionStrictBound(
1110 /*region=*/region,
1111 /*var_dom=*/var_dom,
1112 /*predicate=*/predicate, /*analyzer=*/analyzer)) {
1113 return result.value();
1114 }
1115 Array<IntSet> result;
1116 result.reserve(region.size());
1117 // try estimate each dimension independently
1118 for (const Range& range : region) {
1119 auto res = DetectIterMap(
1120 /*indices=*/{range->min}, /*input_iters=*/var_dom,
1121 /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer);
1122 if (!res->indices.empty()) {
1123 ICHECK_EQ(res->indices.size(), 1U);
1124 IterSumExpr sum_expr = res->indices[0];
1125
1126 // dynamic extent is not supported yet.
1127 PrimExpr extent = range->extent;
1128 if (!is_const_number(extent)) {
1129 IntSet relaxed = EvalSet(extent, AsIntSet(var_dom));
1130 ICHECK(relaxed.HasUpperBound());
1131 extent = relaxed.max();
1132 }
1133
1134 if (Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer)) {
1135 result.push_back(int_set.value());
1136 continue;
1137 }
1138 }
1139 // fallback to coarse grained evalset
1140 result.push_back(EvalSet(range, AsIntSet(var_dom)));
1141 }
1142 return result;
1143}
1144
// Register IntervalSetNode as a node type.
TVM_REGISTER_NODE_TYPE(IntervalSetNode);

// Text form produced by ReprPrinter for an IntervalSetNode: `IntervalSet[<min>, <max>]`.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const IntervalSetNode*>(node.get());
      p->stream << "IntervalSet"
                << "[" << op->min_value << ", " << op->max_value << ']';
    });
1153
// Global-function registrations exposing IntSet constructors and accessors.
TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::SinglePoint);

TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::Vector);

TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::Interval);

TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min);

TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max);

TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::IsNothing);

TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::IsEverything);

// Region-estimation entry points. Each call constructs a fresh Analyzer.
TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound")
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer);
    });
TVM_REGISTER_GLOBAL("arith.EstimateRegionStrictBound")
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer);
    });
// EstimateRegionUpperBound returns a plain Array<IntSet>. The lambda's declared
// Optional<Array<IntSet>> return type wraps that non-optional result.
TVM_REGISTER_GLOBAL("arith.EstimateRegionUpperBound")
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer);
    });

// Accessors for the symbolic infinity markers, and the lower-bound union.
TVM_REGISTER_GLOBAL("arith.PosInf").set_body_typed([]() { return SymbolicLimits::pos_inf_; });
TVM_REGISTER_GLOBAL("arith.NegInf").set_body_typed([]() { return SymbolicLimits::neg_inf_; });
TVM_REGISTER_GLOBAL("arith.UnionLowerBound").set_body_typed(UnionLowerBound);
1190
1191} // namespace arith
1192} // namespace tvm
1193