1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file int_set.cc |
22 | * \brief The integer set functions |
23 | */ |
24 | #include <tvm/arith/int_set.h> |
25 | #include <tvm/arith/iter_affine_map.h> |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/expr_functor.h> |
29 | |
30 | #include <algorithm> |
31 | #include <unordered_map> |
32 | #include <utility> |
33 | |
34 | #include "constraint_extract.h" |
35 | #include "interval_set.h" |
36 | #include "pattern_match.h" |
37 | |
38 | namespace tvm { |
39 | namespace arith { |
40 | |
41 | using tir::is_one; |
42 | using tir::is_zero; |
43 | using tir::make_const; |
44 | using tir::make_zero; |
45 | |
// Definitions of the SymbolicLimits static members. Both are handle-typed
// Vars (named "pos_inf" and "neg_inf") that this file uses as the symbolic
// unbounded endpoints of an interval.
PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle());
PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle());
48 | |
49 | IntervalSet::IntervalSet(PrimExpr min_value, PrimExpr max_value) { |
50 | auto node = make_object<IntervalSetNode>(); |
51 | node->min_value = std::move(min_value); |
52 | node->max_value = std::move(max_value); |
53 | data_ = std::move(node); |
54 | } |
55 | |
56 | IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { |
57 | return IntervalSet(min_value, max_value); |
58 | } |
59 | |
// Register MakeIntervalSet under the global function name "arith.IntervalSet".
TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet);
61 | |
62 | IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { |
63 | PrimExpr max_value = min(a->max_value, b->max_value); |
64 | PrimExpr min_value = max(a->min_value, b->min_value); |
65 | if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && |
66 | (min_value.dtype().is_int() || min_value.dtype().is_uint()) && |
67 | analyzer->CanProve(max_value < min_value)) { |
68 | return IntervalSet::Empty(); |
69 | } else { |
70 | return IntervalSet(min_value, max_value); |
71 | } |
72 | } |
73 | |
74 | IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { |
75 | if (a->IsEmpty()) return b; |
76 | if (b->IsEmpty()) return a; |
77 | PrimExpr max_value = max(a->max_value, b->max_value); |
78 | PrimExpr min_value = min(a->min_value, b->min_value); |
79 | return IntervalSet(min_value, max_value); |
80 | } |
81 | |
// type traits
// Primary template: an operator type is not a logical op unless a
// specialization (see TVM_DECLARE_LOGICAL_OP) says otherwise.
template <typename OP>
struct is_logical_op {
  static const bool value = false;
};
87 | |
// Specializes is_logical_op for tir::OP so that its `value` is true.
#define TVM_DECLARE_LOGICAL_OP(OP)  \
  template <>                       \
  struct is_logical_op<tir::OP> {   \
    static const bool value = true; \
  };

// Operators marked as logical: the generic Combine returns the range [0, 1]
// for these when it cannot fold the operands to a single point.
TVM_DECLARE_LOGICAL_OP(And);
TVM_DECLARE_LOGICAL_OP(Or);
TVM_DECLARE_LOGICAL_OP(EQ);
TVM_DECLARE_LOGICAL_OP(NE);
TVM_DECLARE_LOGICAL_OP(GE);
TVM_DECLARE_LOGICAL_OP(GT);
TVM_DECLARE_LOGICAL_OP(LE);
TVM_DECLARE_LOGICAL_OP(LT);
TVM_DECLARE_LOGICAL_OP(Not);
103 | |
104 | /*! |
105 | * \brief Combine two interval set under arithmetic operations. |
106 | * \note this can possibly relax the set. |
107 | */ |
108 | template <typename Op> |
109 | inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) { |
110 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
111 | PrimExpr expr; |
112 | if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) { |
113 | expr = res.value(); |
114 | } else { |
115 | expr = Op(a->min_value, b->min_value); |
116 | } |
117 | return IntervalSet::SinglePoint(expr); |
118 | } |
119 | if (is_logical_op<Op>::value) { |
120 | return IntervalSet(make_const(dtype, 0), make_const(dtype, 1)); |
121 | } |
122 | if (a->IsEmpty()) return a; |
123 | if (b->IsEmpty()) return b; |
124 | if (a->IsEverything()) return a; |
125 | if (b->IsEverything()) return b; |
126 | return IntervalSet::Everything(); |
127 | } |
128 | |
129 | template <> |
130 | inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, IntervalSet b, |
131 | DataType /* dtype */) { |
132 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
133 | return IntervalSet::SinglePoint(a->min_value + b->min_value); |
134 | } |
135 | if (a->IsEmpty()) return a; |
136 | if (b->IsEmpty()) return b; |
137 | PrimExpr min_value = |
138 | a->HasLowerBound() && b->HasLowerBound() ? a->min_value + b->min_value : neg_inf(); |
139 | PrimExpr max_value = |
140 | a->HasUpperBound() && b->HasUpperBound() ? a->max_value + b->max_value : pos_inf(); |
141 | return IntervalSet(min_value, max_value); |
142 | } |
143 | |
144 | template <> |
145 | inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, IntervalSet b, |
146 | DataType /* dtype */) { |
147 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
148 | return IntervalSet::SinglePoint(a->min_value - b->min_value); |
149 | } |
150 | if (a->IsEmpty()) return a; |
151 | if (b->IsEmpty()) return b; |
152 | PrimExpr min_value = |
153 | a->HasLowerBound() && b->HasUpperBound() ? a->min_value - b->max_value : neg_inf(); |
154 | PrimExpr max_value = |
155 | a->HasUpperBound() && b->HasLowerBound() ? a->max_value - b->min_value : pos_inf(); |
156 | return IntervalSet(min_value, max_value); |
157 | } |
158 | |
159 | template <> |
160 | inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
161 | DataType /* dtype */) { |
162 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
163 | return IntervalSet::SinglePoint(a->min_value * b->min_value); |
164 | } |
165 | if (a->IsEmpty()) return a; |
166 | if (b->IsEmpty()) return b; |
167 | if (a->IsSinglePoint()) { |
168 | std::swap(a, b); |
169 | } |
170 | if (b->IsSinglePoint()) { |
171 | if (is_zero(b->min_value)) return b; |
172 | if (is_one(b->min_value)) return a; |
173 | if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { |
174 | PrimExpr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf(); |
175 | PrimExpr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf(); |
176 | return IntervalSet(min_value, max_value); |
177 | } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { |
178 | PrimExpr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf(); |
179 | PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); |
180 | return IntervalSet(min_value, max_value); |
181 | } else if (a->HasUpperBound() && a->HasLowerBound()) { |
182 | using tir::Select; |
183 | PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); |
184 | PrimExpr e1 = a->min_value * b->min_value; |
185 | PrimExpr e2 = a->max_value * b->min_value; |
186 | return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); |
187 | } |
188 | } |
189 | DLOG(WARNING) << "Return Everything in CombineInterval Mul" ; |
190 | return IntervalSet::Everything(); |
191 | } |
192 | |
193 | template <> |
194 | inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
195 | DataType /* dtype */) { |
196 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
197 | return IntervalSet::SinglePoint(a->min_value / b->min_value); |
198 | } |
199 | if (a->IsEmpty()) return a; |
200 | if (b->IsEmpty()) return b; |
201 | if (b->IsSinglePoint()) { |
202 | if (is_zero(b->min_value)) { |
203 | LOG(FATAL) << "Divide by zero in CombineInterval Div" ; |
204 | } |
205 | if (is_one(b->min_value)) return a; |
206 | // no relaxation is needed in here due to set is inclusive |
207 | if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { |
208 | PrimExpr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf(); |
209 | PrimExpr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf(); |
210 | return IntervalSet(min_value, max_value); |
211 | } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { |
212 | PrimExpr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf(); |
213 | PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); |
214 | return IntervalSet(min_value, max_value); |
215 | } else if (a->HasUpperBound() && a->HasLowerBound()) { |
216 | using tir::Select; |
217 | PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); |
218 | PrimExpr e1 = a->min_value / b->min_value; |
219 | PrimExpr e2 = a->max_value / b->min_value; |
220 | return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); |
221 | } |
222 | } |
223 | DLOG(WARNING) << "Return Everything in CombineInterval Div" ; |
224 | return IntervalSet::Everything(); |
225 | } |
226 | |
227 | template <> |
228 | inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
229 | DataType /* dtype */) { |
230 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
231 | return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); |
232 | } |
233 | if (a->IsEmpty()) return a; |
234 | if (b->IsEmpty()) return b; |
235 | |
236 | if (b->IsSinglePoint()) { |
237 | const PrimExpr& divisor = b->min_value; |
238 | if (is_zero(divisor)) { |
239 | LOG(FATAL) << "Modular by zero in CombineInterval Mod" ; |
240 | } |
241 | // We need to add more bound constraints throughout the code. |
242 | // The logic below assumes a is non-negative, which usually |
243 | // is the case of our application. |
244 | // TODO(tqchen): add bound constraints for a. |
245 | if (analyzer->CanProveGreaterEqual(divisor, 0)) { |
246 | return IntervalSet(make_zero(divisor.dtype()), divisor - 1); |
247 | } else { |
248 | PrimExpr bound = abs(divisor) - 1; |
249 | return IntervalSet(-bound, bound); |
250 | } |
251 | } |
252 | DLOG(WARNING) << "Return Everything in CombineInterval Mod" ; |
253 | return IntervalSet::Everything(); |
254 | } |
255 | |
256 | template <> |
257 | inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
258 | DataType /* dtype */) { |
259 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
260 | return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); |
261 | } |
262 | if (a->IsEmpty()) return a; |
263 | if (b->IsEmpty()) return b; |
264 | if (b->IsSinglePoint()) { |
265 | if (is_zero(b->min_value)) { |
266 | LOG(FATAL) << "Divide by zero in CombineInterval Div" ; |
267 | } |
268 | if (is_one(b->min_value)) return a; |
269 | // no relaxation is needed in here due to set is inclusive |
270 | if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { |
271 | PrimExpr min_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : neg_inf(); |
272 | PrimExpr max_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : pos_inf(); |
273 | return IntervalSet(min_value, max_value); |
274 | } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { |
275 | PrimExpr min_value = a->HasUpperBound() ? floordiv(a->max_value, b->min_value) : neg_inf(); |
276 | PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); |
277 | return IntervalSet(min_value, max_value); |
278 | } else if (a->HasUpperBound() && a->HasLowerBound()) { |
279 | using tir::Select; |
280 | PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); |
281 | PrimExpr e1 = floordiv(a->min_value, b->min_value); |
282 | PrimExpr e2 = floordiv(a->max_value, b->min_value); |
283 | return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); |
284 | } |
285 | } |
286 | DLOG(WARNING) << "Return Everything in CombineInterval Div" ; |
287 | return IntervalSet::Everything(); |
288 | } |
289 | |
290 | template <> |
291 | inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, IntervalSet b, |
292 | DataType /* dtype */) { |
293 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
294 | return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); |
295 | } |
296 | if (a->IsEmpty()) return a; |
297 | if (b->IsEmpty()) return b; |
298 | |
299 | if (b->IsSinglePoint()) { |
300 | const PrimExpr& divisor = b->min_value; |
301 | if (is_zero(divisor)) { |
302 | LOG(FATAL) << "Modular by zero in CombineInterval Mod" ; |
303 | } |
304 | if (analyzer->CanProveGreaterEqual(divisor, 0)) { |
305 | if (divisor.as<tir::IntImmNode>()) { |
306 | // a mod b = a - (a / b) * b if a_max / b == a_min / b |
307 | auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf(); |
308 | auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf(); |
309 | // We can compare +/- inf against each other, but cannot use |
310 | // operator== between the symbolic limits and an integer. |
311 | bool compatible_dtypes = !(qmin.dtype().is_handle() ^ qmax.dtype().is_handle()); |
312 | if (compatible_dtypes && analyzer->CanProve(qmax == qmin)) { |
313 | auto tmax = a->max_value - divisor * qmin; |
314 | auto tmin = a->min_value - divisor * qmin; |
315 | return IntervalSet(tmin, tmax); |
316 | } |
317 | } |
318 | return IntervalSet(make_zero(divisor.dtype()), divisor - 1); |
319 | } else { |
320 | PrimExpr bound = abs(divisor) - 1; |
321 | return IntervalSet(-bound, bound); |
322 | } |
323 | } |
324 | DLOG(WARNING) << "Return Everything in CombineInterval Mod" ; |
325 | return IntervalSet::Everything(); |
326 | } |
327 | |
328 | template <> |
329 | inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, IntervalSet b, |
330 | DataType /* dtype */) { |
331 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
332 | return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); |
333 | } |
334 | if (a->IsEmpty()) return a; |
335 | if (b->IsEmpty()) return b; |
336 | return IntervalSet(max(a->min_value, b->min_value), max(a->max_value, b->max_value)); |
337 | } |
338 | |
339 | template <> |
340 | inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, IntervalSet b, |
341 | DataType /* dtype */) { |
342 | if (a->IsSinglePoint() && b->IsSinglePoint()) { |
343 | return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); |
344 | } |
345 | if (a->IsEmpty()) return a; |
346 | if (b->IsEmpty()) return b; |
347 | return IntervalSet(min(a->min_value, b->min_value), min(a->max_value, b->max_value)); |
348 | } |
349 | |
350 | // internal helper function to get an interval set |
351 | IntervalSet ToIntervalSet(IntSet set) { |
352 | if (auto* node = set.as<IntervalSetNode>()) { |
353 | return GetRef<IntervalSet>(node); |
354 | } |
355 | DLOG(INFO) << "cannot resolve int set " << set; |
356 | return IntervalSet::Everything(); |
357 | } |
358 | |
359 | using namespace tir; |
360 | |
// Simplified version of int set evaluator that operates on IntervalSet
// We might use better set analysis in the future to replace the intervalset.
//
// The evaluator visits an expression and returns an IntervalSet covering the
// values it may take, given per-variable domains (dom_map) and an optional
// list of extra per-variable constraints (dom_constraints).
class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
 public:
  /*!
   * \param analyzer Analyzer forwarded to the Combine/Union helpers.
   * \param dom_map Variable domains. Stored by reference, so the map must
   *        outlive this evaluator.
   * \param dom_constraints Optional extra (var, set) constraints; may be nullptr.
   *        Stored as a raw pointer, so the vector must outlive this evaluator.
   * \param eval_vec Whether Ramp/Broadcast nodes may be evaluated (they
   *        ICHECK on this flag).
   */
  IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map,
                       const std::vector<std::pair<Var, IntSet>>* dom_constraints = nullptr,
                       bool eval_vec = false)
      : analyzer_(analyzer),
        dom_map_(dom_map),
        dom_constraints_(dom_constraints),
        eval_vec_(eval_vec) {}

  // Evaluate an expression by dispatching to the matching VisitExpr_ overload.
  IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); }
  // evaluate and relax the set
  // The endpoints of `val` are themselves evaluated; the result spans from
  // the minimum of the lower endpoint to the maximum of the upper endpoint.
  IntervalSet Eval(IntervalSet val) {
    // avoid unbounded recursive expansion: stop once the recursion depth
    // reaches the number of entries in dom_map_.
    if (static_cast<size_t>(recur_depth_) >= dom_map_.size()) return val;
    ++recur_depth_;
    IntervalSet min_set = this->Eval(val->min_value);
    IntervalSet max_set = this->Eval(val->max_value);
    --recur_depth_;
    return IntervalSet(min_set->min_value, max_set->max_value);
  }

  // An integer immediate is the single point holding itself.
  IntervalSet VisitExpr_(const IntImmNode* op) final {
    return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
  }

  // A variable evaluates to the intersection of every known set for it
  // (matching constraints plus its dom_map_ entry), or to itself as a single
  // point when nothing is known about it.
  IntervalSet VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);

    // Collect all sets that constrain this variable.
    Array<IntSet> values;
    if (dom_constraints_) {
      for (const auto& constraint : *dom_constraints_) {
        if (var.same_as(constraint.first)) {
          values.push_back(constraint.second);
        }
      }
    }

    auto it = dom_map_.find(var);
    if (it != dom_map_.end()) {
      values.push_back((*it).second);
    }

    if (values.empty()) {
      return IntervalSet::SinglePoint(var);
    }

    // Skip the Intersect call when there is only one candidate set.
    IntSet intersection = [&]() {
      if (values.size() == 1) {
        return values.front();
      } else {
        return Intersect(values);
      }
    }();

    IntervalSet res = ToIntervalSet(intersection);
    // The variable maps to itself: nothing further to expand.
    if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
      return res;
    }
    // recursively evaluate mapped result
    // in case the domain contains variables to be relaxed.
    return Eval(res);
  }

  // Binary operators: all forward to VisitBinaryExpr_ with the matching
  // operator type, which selects the Combine specialization.
  IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }

  IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_<Sub>(op); }

  IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_<Mul>(op); }

  IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_<Div>(op); }

  IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_<Mod>(op); }

  IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_<FloorDiv>(op); }

  IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_<FloorMod>(op); }

  IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_<Min>(op); }

  IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_<Max>(op); }

  IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_<EQ>(op); }

  IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_<NE>(op); }

  IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_<LT>(op); }

  IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_<LE>(op); }

  IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_<GT>(op); }

  IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_<GE>(op); }

  IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_<And>(op); }

  IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_<Or>(op); }

  // Ramp: only handled when the stride is an integer immediate. The result is
  // the base interval plus the span covered by the lanes, i.e.
  // [0, stride * (lanes - 1)] for a positive stride and
  // [stride * (lanes - 1), 0] otherwise. Requires eval_vec_.
  IntervalSet VisitExpr_(const RampNode* op) final {
    ICHECK(eval_vec_);
    IntervalSet base = Eval(op->base);
    PVar<IntImm> stride;
    if (stride.Match(op->stride)) {
      DataType t = op->base.dtype();
      int64_t vstride = stride.Eval()->value;
      if (vstride > 0) {
        return Combine<Add>(analyzer_, base,
                            IntervalSet(make_zero(t), make_const(t, vstride * (op->lanes - 1))),
                            op->dtype);
      } else {
        return Combine<Add>(analyzer_, base,
                            IntervalSet(make_const(t, vstride * (op->lanes - 1)), make_zero(t)),
                            op->dtype);
      }
    }
    DLOG(WARNING) << "cannot evaluate set on expression " << GetRef<PrimExpr>(op);
    return IntervalSet::Everything();
  }

  // Broadcast: the interval of the broadcast value. Requires eval_vec_.
  IntervalSet VisitExpr_(const BroadcastNode* op) final {
    ICHECK(eval_vec_);
    return VisitExpr(op->value);
  }

  // Select: the union of both branches; the condition is not inspected.
  IntervalSet VisitExpr_(const SelectNode* op) final {
    IntervalSet true_set = this->Eval(op->true_value);
    IntervalSet false_set = this->Eval(op->false_value);
    return Union(analyzer_, false_set, true_set);
  }

  // Cast: cast each bounded endpoint to the target dtype; an unbounded
  // endpoint stays the symbolic infinity.
  IntervalSet VisitExpr_(const CastNode* op) final {
    IntervalSet value_set = this->Eval(op->value);
    PrimExpr min_value =
        value_set->HasLowerBound() ? cast(op->dtype, value_set->min_value) : neg_inf();
    PrimExpr max_value =
        value_set->HasUpperBound() ? cast(op->dtype, value_set->max_value) : pos_inf();
    return IntervalSet(min_value, max_value);
  }

  IntervalSet VisitExpr_(const BufferLoadNode* op) final {
    // Only int/uint loads are considered; anything else is "everything".
    if (!(op->dtype.is_int() || op->dtype.is_uint())) {
      DLOG(WARNING) << "cannot evaluate set BufferLoad which loads from a " << op->dtype
                    << " buffer";
      return IntervalSet::Everything();
    }
    // If the indices do not contain any variables to be relaxed, return the BufferLoad itself.
    // Otherwise return `IntervalSet::everything()` since we have no knowledge on the buffer data.
    // Note: only dom_map_ is consulted here, not dom_constraints_.
    for (const PrimExpr& index : op->indices) {
      if (UsesVar(index, [dom_map = &this->dom_map_](const VarNode* var) {
            return dom_map->find(GetRef<Var>(var)) != dom_map->end();
          })) {
        return IntervalSet::Everything();
      }
    }
    return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
  }

  // Any node type without a dedicated overload relaxes to everything.
  IntervalSet VisitExprDefault_(const Object* op) final {
    DLOG(WARNING) << "cannot evaluate set type " << op->GetTypeKey();
    return IntervalSet::Everything();
  }

 private:
  // whether set is exactly single point that equals value.
  // The comparison is by reference identity (same_as), not structural equality.
  bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const {
    return set->min_value.same_as(value) && set->max_value.same_as(value);
  }

  // Shared implementation of all binary operators: evaluate both operands,
  // return the original node unchanged when neither operand was relaxed,
  // otherwise combine the two intervals with Combine<TOp>.
  template <typename TOp, typename T>
  inline IntervalSet VisitBinaryExpr_(const T* op) {
    static_assert(std::is_same<typename TOp::ContainerType, T>::value, "constraint");
    IntervalSet a = this->Eval(op->a);
    IntervalSet b = this->Eval(op->b);
    if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
      return IntervalSet::SinglePoint(GetRef<PrimExpr>(op));
    }
    return Combine<TOp>(analyzer_, a, b, op->dtype);
  }

  // recursive depth
  int recur_depth_{0};
  // analyzer
  Analyzer* analyzer_;
  // Variable domains; held by reference (not owned).
  const Map<Var, IntSet>& dom_map_;
  // Optional extra per-variable constraints; may be nullptr (not owned).
  const std::vector<std::pair<Var, IntSet>>* dom_constraints_;
  // Whether Ramp/Broadcast nodes are allowed to be evaluated.
  bool eval_vec_{false};
};
550 | |
// Private implementation of IntSetAnalyzer: stores the bound variable
// domains and the currently active scope constraints, and evaluates
// expressions against them with IntervalSetEvaluator.
class IntSetAnalyzer::Impl {
 public:
  explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {}

  // Evaluate `expr` using only the caller-provided `dom_map`; the stored
  // domains and constraints are not consulted and vector evaluation is off.
  IntSet Eval(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const {
    return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
  }

  // Evaluate `expr` using the stored domains and constraints, with vector
  // (Ramp/Broadcast) evaluation enabled.
  IntSet Eval(const PrimExpr& expr) const {
    return IntervalSetEvaluator(analyzer_, dom_map_, &dom_constraints_, true).Eval(expr);
  }

  // Bind `var` to the set derived from `range`.
  void Bind(const Var& var, const Range& range, bool allow_override) {
    Update(var, IntSet::FromRange(range), allow_override);
  }

  // Record `info` as the domain of `var` (defined out of line).
  void Update(const Var& var, const IntSet& info, bool override_info);
  // Bind `var` to the set obtained by evaluating `expr` (defined out of line).
  void Bind(const Var& var, const PrimExpr& expr, bool override_info);
  // Add the bounds implied by `constraint`; returns a callback that removes
  // them again, or nullptr when no bounds were detected (defined out of line).
  std::function<void()> EnterConstraint(const PrimExpr& constraint);

 private:
  // Utility function to split a boolean condition into the domain
  // bounds implied by that condition.
  static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr& cond);

  // The parent arith::Analyzer
  Analyzer* analyzer_;

  // Map of variables to global variable bounds (e.g. loop iterator
  // ranges)
  Map<Var, IntSet> dom_map_;

  // List of implicit scope-dependent bounds (e.g. inside the body of
  // an if-statement). Maintained as a list of constraints, rather
  // than as a `Map<Var,IntSet>`, to avoid computing an Intersection
  // until required.
  std::vector<std::pair<Var, IntSet>> dom_constraints_;
};
589 | |
// Allocate the private implementation; it is released in the destructor.
IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}

// Release the private implementation allocated in the constructor.
IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; }

// Evaluate `expr` against the caller-provided variable domains.
IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) {
  return impl_->Eval(expr, dom_map);
}

// Evaluate `expr` against the domains and constraints stored in the analyzer.
IntSet IntSetAnalyzer::operator()(const PrimExpr& expr) { return impl_->Eval(expr); }

// Forward to Impl::Update.
void IntSetAnalyzer::Update(const Var& var, const IntSet& info, bool allow_override) {
  impl_->Update(var, info, allow_override);
}

// Forward to Impl::Bind (Range overload).
void IntSetAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
  impl_->Bind(var, range, allow_override);
}
607 | |
608 | void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_override) { |
609 | if (!can_override) { |
610 | auto it = dom_map_.find(var); |
611 | if (it != dom_map_.end()) { |
612 | const IntSet& old_info = (*it).second; |
613 | |
614 | ICHECK(ExprDeepEqual()(old_info.min(), info.min())) |
615 | << "Trying to update var \'" << var << "\'" |
616 | << " with a different minimum value: " |
617 | << "original=" << old_info.min() << ", new=" << info.min(); |
618 | |
619 | ICHECK(ExprDeepEqual()(old_info.max(), info.max())) |
620 | << "Trying to update var \'" << var << "\'" |
621 | << " with a different maximum value: " |
622 | << "original=" << old_info.max() << ", new=" << info.max(); |
623 | } |
624 | } |
625 | dom_map_.Set(var, info); |
626 | } |
627 | |
628 | void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_override) { |
629 | Update(var, Eval(expr), can_override); |
630 | } |
631 | |
632 | std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo( |
633 | const PrimExpr& constraint) { |
634 | PVar<Var> x; |
635 | PVar<PrimExpr> limit; |
636 | |
637 | std::vector<std::pair<Var, IntSet>> bounds; |
638 | for (const PrimExpr& subconstraint : ExtractConstraints(constraint)) { |
639 | if ((x <= limit).Match(subconstraint)) { |
640 | bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())}); |
641 | } else if ((x < limit).Match(subconstraint)) { |
642 | bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)}); |
643 | } else if ((x >= limit).Match(subconstraint)) { |
644 | bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)}); |
645 | } else if ((x > limit).Match(subconstraint)) { |
646 | bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)}); |
647 | } else if ((x == limit).Match(subconstraint)) { |
648 | bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())}); |
649 | } |
650 | |
651 | if ((limit >= x).Match(subconstraint)) { |
652 | bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval())}); |
653 | } else if ((limit > x).Match(subconstraint)) { |
654 | bounds.push_back({x.Eval(), IntSet::Interval(SymbolicLimits::neg_inf_, limit.Eval() - 1)}); |
655 | } else if ((limit <= x).Match(subconstraint)) { |
656 | bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval(), SymbolicLimits::pos_inf_)}); |
657 | } else if ((limit < x).Match(subconstraint)) { |
658 | bounds.push_back({x.Eval(), IntSet::Interval(limit.Eval() + 1, SymbolicLimits::pos_inf_)}); |
659 | } else if ((limit == x).Match(subconstraint)) { |
660 | bounds.push_back({x.Eval(), IntSet::SinglePoint(limit.Eval())}); |
661 | } |
662 | } |
663 | return bounds; |
664 | } |
665 | |
// Forward to Impl::EnterConstraint; the returned callback (possibly nullptr)
// is passed through unchanged.
std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
  return impl_->EnterConstraint(constraint);
}
669 | |
670 | std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) { |
671 | auto bounds = DetectBoundInfo(constraint); |
672 | |
673 | if (bounds.size() == 0) return nullptr; |
674 | |
675 | size_t old_size = dom_constraints_.size(); |
676 | dom_constraints_.insert(dom_constraints_.end(), bounds.begin(), bounds.end()); |
677 | size_t new_size = dom_constraints_.size(); |
678 | auto frecover = [old_size, new_size, this]() { |
679 | ICHECK_EQ(dom_constraints_.size(), new_size); |
680 | dom_constraints_.resize(old_size); |
681 | }; |
682 | return frecover; |
683 | } |
684 | |
685 | // Quickly adapt to IntSet interface |
686 | // TODO(tqchen): revisit IntSet interface as well. |
687 | Range IntSet::CoverRange(Range max_range) const { |
688 | IntSet temp; |
689 | Analyzer analyzer; |
690 | const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
691 | ICHECK(s_int != nullptr); |
692 | if (s_int->HasUpperBound() && s_int->HasLowerBound()) { |
693 | return Range::FromMinExtent(analyzer.Simplify(s_int->min_value), |
694 | analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); |
695 | } |
696 | return max_range; |
697 | } |
698 | |
699 | PrimExpr IntSet::min() const { |
700 | const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
701 | ICHECK(s_int); |
702 | return s_int->min_value; |
703 | } |
704 | |
705 | PrimExpr IntSet::max() const { |
706 | const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
707 | ICHECK(s_int); |
708 | return s_int->max_value; |
709 | } |
710 | |
711 | bool IntSet::IsNothing() const { |
712 | const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
713 | return (s_int && s_int->IsEmpty()); |
714 | } |
715 | |
716 | bool IntSet::IsEverything() const { |
717 | const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
718 | return (s_int && s_int->IsEverything()); |
719 | } |
720 | |
721 | bool IntSet::IsSinglePoint() const { |
722 | const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
723 | return (s_int && s_int->IsSinglePoint()); |
724 | } |
725 | |
726 | bool IntSet::CanProvePositive() const { |
727 | Analyzer analyzer; |
728 | const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
729 | return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value))); |
730 | } |
731 | |
732 | bool IntSet::CanProveNegative() const { |
733 | Analyzer analyzer; |
734 | const IntervalSetNode* s_int = (*this).as<IntervalSetNode>(); |
735 | return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value))); |
736 | } |
737 | |
738 | bool IntSet::CanProveNonPositive() const { |
739 | Analyzer analyzer; |
740 | if (const auto* s_int = (*this).as<IntervalSetNode>()) { |
741 | auto max = analyzer.Simplify(s_int->max_value); |
742 | return is_zero(max) || is_negative_const(max); |
743 | } |
744 | return false; |
745 | } |
746 | |
747 | bool IntSet::CanProveNonNegative() const { |
748 | Analyzer analyzer; |
749 | if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) { |
750 | auto min = analyzer.Simplify(s_int->min_value); |
751 | return is_zero(min) || is_positive_const(min); |
752 | } |
753 | return false; |
754 | } |
755 | |
756 | bool IntSet::HasLowerBound() const { |
757 | if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) { |
758 | return s_int->HasLowerBound(); |
759 | } |
760 | return false; |
761 | } |
762 | |
763 | bool IntSet::HasUpperBound() const { |
764 | if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) { |
765 | return s_int->HasUpperBound(); |
766 | } |
767 | return false; |
768 | } |
769 | |
770 | SignType IntSet::GetSignType() const { |
771 | if (CanProvePositive()) { |
772 | return kPositive; |
773 | } else if (CanProveNegative()) { |
774 | return kNegative; |
775 | } else if (IsSinglePoint() && is_zero(PointValue())) { |
776 | return kZero; |
777 | } else { |
778 | return kUnknown; |
779 | } |
780 | } |
/*!
 * \brief Return the value held by a single-point set.
 * \return min_value of the underlying IntervalSetNode.
 * \note Fails an ICHECK when this set is not an IntervalSetNode or that node
 *       does not report IsSinglePoint().
 */
PrimExpr IntSet::PointValue() const {
  const IntervalSetNode* s_int = (*this).as<IntervalSetNode>();
  // Guard before dereferencing: callers must only ask for a point value on a single point.
  ICHECK(s_int && s_int->IsSinglePoint());
  return s_int->min_value;
}
786 | |
/*! \brief Construct the set produced by IntervalSet::Empty(). */
IntSet IntSet::Nothing() { return IntervalSet::Empty(); }
788 | |
/*! \brief Construct the set produced by IntervalSet::Everything(). */
IntSet IntSet::Everything() { return IntervalSet::Everything(); }
790 | |
/*! \brief Construct a set for the single value x via IntervalSet::SinglePoint(). */
IntSet IntSet::SinglePoint(PrimExpr x) { return IntervalSet::SinglePoint(x); }
792 | |
793 | IntSet IntSet::Interval(PrimExpr min, PrimExpr max) { |
794 | if (min.same_as(max)) { |
795 | return IntSet::SinglePoint(min); |
796 | } |
797 | return IntervalSet(min, max); |
798 | } |
799 | |
800 | // Range related code |
801 | inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) { |
802 | return is_zero(analyzer->Simplify(lhs - rhs)); |
803 | } |
804 | |
805 | IntSet IntSet::FromMinExtent(PrimExpr min, PrimExpr extent) { |
806 | if (is_one(extent)) { |
807 | return IntSet::SinglePoint(min); |
808 | } |
809 | return IntervalSet(min, extent + min - 1); |
810 | } |
811 | |
812 | IntSet IntSet::FromRange(Range r) { |
813 | // must make sure it can be matched back by MatchRange. |
814 | if (is_one(r->extent)) { |
815 | return IntSet::SinglePoint(r->min); |
816 | } |
817 | return IntervalSet(r->min, r->extent + r->min - 1); |
818 | } |
819 | |
820 | bool IntSet::MatchRange(const Range& b) const { |
821 | const IntSet& a = *this; |
822 | const IntervalSetNode* a_int = a.as<IntervalSetNode>(); |
823 | if (!a_int) return false; |
824 | if (!a_int->HasUpperBound() || !a_int->HasLowerBound()) return false; |
825 | Analyzer ana; |
826 | return ProveEqual(&ana, a_int->min_value, b->min) && |
827 | ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); |
828 | } |
829 | |
830 | IntSet Union(const Array<IntSet>& sets) { |
831 | if (sets.size() == 0) return IntSet::Nothing(); |
832 | if (sets.size() == 1) return sets[0]; |
833 | Analyzer ana; |
834 | IntervalSet x = ToIntervalSet(sets[0]); |
835 | for (size_t i = 1; i < sets.size(); ++i) { |
836 | x = Union(&ana, x, ToIntervalSet(sets[i])); |
837 | } |
838 | return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); |
839 | } |
840 | |
841 | Array<IntSet> UnionRegion(const Array<Array<IntSet>>& nd_int_sets) { |
842 | if (nd_int_sets.empty()) { |
843 | return {}; |
844 | } |
845 | int n = nd_int_sets.size(); |
846 | int ndim = nd_int_sets[0].size(); |
847 | Array<IntSet> result; |
848 | result.reserve(ndim); |
849 | for (int i = 0; i < ndim; ++i) { |
850 | Array<IntSet> candidates; |
851 | candidates.reserve(n); |
852 | for (int j = 0; j < n; ++j) { |
853 | candidates.push_back(nd_int_sets[j][i]); |
854 | } |
855 | result.push_back(Union(candidates)); |
856 | } |
857 | return result; |
858 | } |
859 | |
860 | IntSet UnionLowerBound(const Array<IntSet>& sets) { |
861 | if (sets.size() == 0) return IntSet::Nothing(); |
862 | if (sets.size() == 1) return sets[0]; |
863 | Analyzer analyzer; |
864 | bool is_first_interval = true; |
865 | PrimExpr min_inclusive{nullptr}; |
866 | PrimExpr max_inclusive(nullptr); |
867 | for (const IntSet& int_set : sets) { |
868 | if (const auto* interval_set = int_set.as<IntervalSetNode>()) { |
869 | PrimExpr new_min_inclusive = interval_set->min_value; |
870 | PrimExpr new_max_inclusive = interval_set->max_value; |
871 | if (is_first_interval) { |
872 | is_first_interval = false; |
873 | min_inclusive = std::move(new_min_inclusive); |
874 | max_inclusive = std::move(new_max_inclusive); |
875 | continue; |
876 | } |
877 | bool bound_1 = is_neg_inf(new_min_inclusive) || is_pos_inf(max_inclusive) || |
878 | analyzer.CanProve(new_min_inclusive <= max_inclusive + 1); |
879 | bool bound_2 = is_neg_inf(min_inclusive) || is_pos_inf(new_max_inclusive) || |
880 | analyzer.CanProve(min_inclusive <= new_max_inclusive + 1); |
881 | if (bound_1 && bound_2) { |
882 | min_inclusive = min(min_inclusive, new_min_inclusive); |
883 | max_inclusive = max(max_inclusive, new_max_inclusive); |
884 | } |
885 | } |
886 | } |
887 | if (is_first_interval) { |
888 | return IntSet::Nothing(); |
889 | } |
890 | return IntSet::Interval(min_inclusive, max_inclusive); |
891 | } |
892 | |
893 | Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets) { |
894 | if (nd_int_sets.empty()) { |
895 | return {}; |
896 | } |
897 | int n = nd_int_sets.size(); |
898 | int ndim = nd_int_sets[0].size(); |
899 | Array<IntSet> result; |
900 | result.reserve(ndim); |
901 | for (int i = 0; i < ndim; ++i) { |
902 | Array<IntSet> candidates; |
903 | candidates.reserve(n); |
904 | for (int j = 0; j < n; ++j) { |
905 | candidates.push_back(nd_int_sets[j][i]); |
906 | } |
907 | result.push_back(UnionLowerBound(candidates)); |
908 | } |
909 | return result; |
910 | } |
911 | |
912 | IntSet Intersect(const Array<IntSet>& sets) { |
913 | if (sets.size() == 0) return IntSet::Nothing(); |
914 | if (sets.size() == 1) return sets[0]; |
915 | Analyzer ana; |
916 | IntervalSet x = ToIntervalSet(sets[0]); |
917 | for (size_t i = 1; i < sets.size(); ++i) { |
918 | x = Intersect(&ana, x, ToIntervalSet(sets[i])); |
919 | } |
920 | return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); |
921 | } |
922 | |
923 | Map<Var, IntSet> ConvertDomMap(const Map<IterVar, IntSet>& dom_map) { |
924 | Map<Var, IntSet> dmap; |
925 | for (auto kv : dom_map) { |
926 | dmap.Set(kv.first->var, kv.second); |
927 | } |
928 | return dmap; |
929 | } |
930 | |
931 | Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
932 | Map<Var, IntSet> dmap; |
933 | for (auto kv : dom_map) { |
934 | dmap.Set(GetRef<Var>(kv.first), kv.second); |
935 | } |
936 | return dmap; |
937 | } |
938 | |
939 | IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) { |
940 | Analyzer ana; |
941 | return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e); |
942 | } |
943 | |
944 | IntSet IntSet::Vector(PrimExpr x) { |
945 | Analyzer ana; |
946 | Map<Var, IntSet> dmap; |
947 | return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); |
948 | } |
949 | |
950 | IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) { |
951 | return EvalSet(e, ConvertDomMap(dom_map)); |
952 | } |
953 | |
954 | IntSet EvalSet(PrimExpr e, const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
955 | return EvalSet(e, ConvertDomMap(dom_map)); |
956 | } |
957 | |
958 | IntSet EvalSet(Range r, const Map<Var, IntSet>& dom_map) { |
959 | Analyzer ana; |
960 | if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) { |
961 | return EvalSet(r->min, dom_map); |
962 | } |
963 | IntervalSetEvaluator m(&ana, dom_map); |
964 | // Simplifying first can give tighter bounds if r->min and r->extent share variables |
965 | PrimExpr sum = r->min + r->extent - 1; |
966 | auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); |
967 | return std::move(res); |
968 | } |
969 | |
970 | IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
971 | return EvalSet(r, ConvertDomMap(dom_map)); |
972 | } |
973 | |
974 | Array<IntSet> EvalSet(const Array<Range>& region, const Map<Var, IntSet>& dom_map) { |
975 | Analyzer ana; |
976 | IntervalSetEvaluator m(&ana, dom_map); |
977 | Array<IntSet> result; |
978 | result.reserve(region.size()); |
979 | for (const Range& r : region) { |
980 | PrimExpr sum = r->min + (r->extent - 1); |
981 | result.push_back(m.Eval(IntervalSet(r->min, ana.Simplify(sum)))); |
982 | } |
983 | return result; |
984 | } |
985 | |
986 | IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
987 | Analyzer ana; |
988 | auto dmap = ConvertDomMap(dom_map); |
989 | IntervalSetEvaluator m(&ana, dmap); |
990 | const IntervalSetNode* s_int = s.as<IntervalSetNode>(); |
991 | PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value; |
992 | PrimExpr vmin = s_int->HasLowerBound() ? m.Eval(s_int->min_value).min() : s_int->min_value; |
993 | return IntervalSet(vmin, vmax); |
994 | } |
995 | |
996 | class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { |
997 | public: |
998 | explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map) |
999 | : IntervalSetEvaluator(analyzer, dom_map) {} |
1000 | |
1001 | IntervalSet VisitExpr(const PrimExpr& n) final { |
1002 | IntervalSet ret = IntervalSetEvaluator::VisitExpr(n); |
1003 | expr_map[n] = ret; |
1004 | return ret; |
1005 | } |
1006 | |
1007 | ExprIntSetMap expr_map; |
1008 | }; |
1009 | |
1010 | ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, |
1011 | const std::unordered_map<const VarNode*, IntSet>& dom_map) { |
1012 | Analyzer ana; |
1013 | auto dmap = ConvertDomMap(dom_map); |
1014 | SubExprIntervalSetEvaluator m(&ana, dmap); |
1015 | m.Eval(e); |
1016 | return m.expr_map; |
1017 | } |
1018 | |
1019 | IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map) { |
1020 | return EvalSet(r, ConvertDomMap(dom_map)); |
1021 | } |
1022 | |
1023 | Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) { |
1024 | Map<Var, arith::IntSet> result; |
1025 | for (auto kv : var_dom) { |
1026 | const Var& var = kv.first; |
1027 | const Range& range = kv.second; |
1028 | result.Set(var, arith::IntSet::FromRange(range)); |
1029 | } |
1030 | return result; |
1031 | } |
1032 | |
1033 | /*! \brief Helper function to convert IterSumExpr to the actual touched range. */ |
1034 | static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, |
1035 | Analyzer* analyzer) { |
1036 | if (iter_min->args.empty()) { |
1037 | return IntSet::FromMinExtent(iter_min->base, extent); |
1038 | } |
1039 | ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr" ; |
1040 | const IterSplitExpr& split = iter_min->args[0]; |
1041 | if (!analyzer->CanProve(extent >= split->scale)) { |
1042 | return NullOpt; |
1043 | } |
1044 | |
1045 | const PrimExpr& base = iter_min->base; |
1046 | // IterSplitExpr: (source // lower_factor) % extent * scale |
1047 | // where `(source // lower_factor) % extent` is within [0, extent - 1] |
1048 | if (analyzer->CanProve(split->scale < 0)) { |
1049 | // If scale is negative, the var dom is [(extent - 1) * scale, 0] |
1050 | // The total base is `base + (extent - 1) * scale`, |
1051 | // while total extent is `dom_extent + (extent - 1) * (-scale)` |
1052 | const PrimExpr& var_extent = (split->extent - 1) * split->scale; |
1053 | return IntSet::FromMinExtent(base + var_extent, extent - var_extent); |
1054 | } else { |
1055 | // If scale is positive, the var dom is [0, (extent - 1) * scale] |
1056 | // The total dom is [base, dom_extent + (extent - 1) * scale] |
1057 | return IntSet::FromMinExtent(base, extent + (split->extent - 1) * split->scale); |
1058 | } |
1059 | } |
1060 | |
1061 | Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region, |
1062 | const Map<Var, Range>& var_dom, |
1063 | const PrimExpr& predicate, Analyzer* analyzer) { |
1064 | int ndim = region.size(); |
1065 | Array<IterSumExpr> iter_sum_exprs{nullptr}; |
1066 | { |
1067 | Array<PrimExpr> affine_indices; |
1068 | affine_indices.reserve(ndim); |
1069 | for (const Range& range : region) { |
1070 | if (!is_const_number(range->extent)) { |
1071 | // dynamic extent is not supported yet. |
1072 | return NullOpt; |
1073 | } |
1074 | affine_indices.push_back(range->min); |
1075 | } |
1076 | auto res = DetectIterMap( |
1077 | /*indices=*/affine_indices, /*input_iters=*/var_dom, |
1078 | /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); |
1079 | iter_sum_exprs = res->indices; |
1080 | } |
1081 | if (iter_sum_exprs.empty()) { |
1082 | return NullOpt; |
1083 | } |
1084 | ICHECK_EQ(iter_sum_exprs.size(), ndim); |
1085 | Array<IntSet> result; |
1086 | result.reserve(ndim); |
1087 | for (int i = 0; i < ndim; ++i) { |
1088 | const IterSumExpr& sum_expr = iter_sum_exprs[i]; |
1089 | const Range& range = region[i]; |
1090 | Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer); |
1091 | if (int_set.defined()) { |
1092 | result.push_back(int_set.value()); |
1093 | } else { |
1094 | return NullOpt; |
1095 | } |
1096 | } |
1097 | return result; |
1098 | } |
1099 | |
/*!
 * \brief Estimate a lower bound of the touched region.
 *
 * Forwards all arguments unchanged to EstimateRegionStrictBound and returns its result.
 */
Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
                                                 const Map<Var, Range>& var_dom,
                                                 const PrimExpr& predicate,
                                                 arith::Analyzer* analyzer) {
  return EstimateRegionStrictBound(region, var_dom, predicate, analyzer);
}
1106 | |
1107 | Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region, const Map<Var, Range>& var_dom, |
1108 | const PrimExpr& predicate, Analyzer* analyzer) { |
1109 | if (Optional<Array<arith::IntSet>> result = EstimateRegionStrictBound( |
1110 | /*region=*/region, |
1111 | /*var_dom=*/var_dom, |
1112 | /*predicate=*/predicate, /*analyzer=*/analyzer)) { |
1113 | return result.value(); |
1114 | } |
1115 | Array<IntSet> result; |
1116 | result.reserve(region.size()); |
1117 | // try estimate each dimension independently |
1118 | for (const Range& range : region) { |
1119 | auto res = DetectIterMap( |
1120 | /*indices=*/{range->min}, /*input_iters=*/var_dom, |
1121 | /*predicate=*/predicate, /*check_level=*/IterMapLevel::Surjective, analyzer); |
1122 | if (!res->indices.empty()) { |
1123 | ICHECK_EQ(res->indices.size(), 1U); |
1124 | IterSumExpr sum_expr = res->indices[0]; |
1125 | |
1126 | // dynamic extent is not supported yet. |
1127 | PrimExpr extent = range->extent; |
1128 | if (!is_const_number(extent)) { |
1129 | IntSet relaxed = EvalSet(extent, AsIntSet(var_dom)); |
1130 | ICHECK(relaxed.HasUpperBound()); |
1131 | extent = relaxed.max(); |
1132 | } |
1133 | |
1134 | if (Optional<IntSet> int_set = EvalIterSum(sum_expr, range->extent, analyzer)) { |
1135 | result.push_back(int_set.value()); |
1136 | continue; |
1137 | } |
1138 | } |
1139 | // fallback to coarse grained evalset |
1140 | result.push_back(EvalSet(range, AsIntSet(var_dom))); |
1141 | } |
1142 | return result; |
1143 | } |
1144 | |
// Register IntervalSetNode through the TVM_REGISTER_NODE_TYPE macro.
TVM_REGISTER_NODE_TYPE(IntervalSetNode);
1146 | |
1147 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
1148 | .set_dispatch<IntervalSetNode>([](const ObjectRef& node, ReprPrinter* p) { |
1149 | auto* op = static_cast<const IntervalSetNode*>(node.get()); |
1150 | p->stream << "IntervalSet" |
1151 | << "[" << op->min_value << ", " << op->max_value << ']'; |
1152 | }); |
1153 | |
// Global registrations for the IntSet constructors and basic queries.
// Each entry binds the quoted name to the IntSet function or method given.

// Constructors.
TVM_REGISTER_GLOBAL("arith.intset_single_point" ).set_body_typed(IntSet::SinglePoint);

TVM_REGISTER_GLOBAL("arith.intset_vector" ).set_body_typed(IntSet::Vector);

TVM_REGISTER_GLOBAL("arith.intset_interval" ).set_body_typed(IntSet::Interval);

// Bound accessors.
TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin" ).set_body_method(&IntSet::min);

TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax" ).set_body_method(&IntSet::max);

// Predicates.
TVM_REGISTER_GLOBAL("arith.IntSetIsNothing" ).set_body_method(&IntSet::IsNothing);

TVM_REGISTER_GLOBAL("arith.IntSetIsEverything" ).set_body_method(&IntSet::IsEverything);
1167 | |
// Global registrations for the region-bound estimators. The registered
// lambdas take no analyzer argument: each one creates a fresh local Analyzer
// and passes its address to the underlying function.
TVM_REGISTER_GLOBAL("arith.EstimateRegionLowerBound" )
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer);
    });
TVM_REGISTER_GLOBAL("arith.EstimateRegionStrictBound" )
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer);
    });
// EstimateRegionUpperBound returns a plain Array<IntSet>; the lambda's declared
// return type wraps it in an Optional.
TVM_REGISTER_GLOBAL("arith.EstimateRegionUpperBound" )
    .set_body_typed([](Array<Range> region, Map<Var, Range> var_dom,
                       PrimExpr predicate) -> Optional<Array<IntSet>> {
      Analyzer analyzer;
      return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer);
    });
1186 | |
// Global registrations returning the symbolic infinity expressions
// (SymbolicLimits::pos_inf_ / neg_inf_), plus the UnionLowerBound entry point.
TVM_REGISTER_GLOBAL("arith.PosInf" ).set_body_typed([]() { return SymbolicLimits::pos_inf_; });
TVM_REGISTER_GLOBAL("arith.NegInf" ).set_body_typed([]() { return SymbolicLimits::neg_inf_; });
TVM_REGISTER_GLOBAL("arith.UnionLowerBound" ).set_body_typed(UnionLowerBound);
1190 | |
1191 | } // namespace arith |
1192 | } // namespace tvm |
1193 | |