1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/const_int_bound.cc |
22 | */ |
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr_functor.h>

#include <algorithm>
#include <cmath>

#include "constraint_extract.h"
#include "int_operator.h"
#include "pattern_match.h"
33 | |
34 | namespace tvm { |
35 | namespace arith { |
36 | |
37 | using namespace tir; |
38 | |
39 | TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); |
40 | |
41 | ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) { |
42 | auto node = make_object<ConstIntBoundNode>(); |
43 | node->min_value = min_value; |
44 | node->max_value = max_value; |
45 | data_ = std::move(node); |
46 | } |
47 | |
48 | ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { |
49 | return ConstIntBound(min_value, max_value); |
50 | } |
51 | |
52 | TVM_REGISTER_GLOBAL("arith.ConstIntBound" ).set_body_typed(MakeConstIntBound); |
53 | |
54 | inline void PrintBoundValue(std::ostream& os, int64_t val) { |
55 | if (val == ConstIntBound::kPosInf) { |
56 | os << "pos_inf" ; |
57 | } else if (val == ConstIntBound::kNegInf) { |
58 | os << "neg_inf" ; |
59 | } else { |
60 | os << val; |
61 | } |
62 | } |
63 | |
64 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
65 | .set_dispatch<ConstIntBoundNode>([](const ObjectRef& node, ReprPrinter* p) { |
66 | auto* op = static_cast<const ConstIntBoundNode*>(node.get()); |
67 | p->stream << "ConstIntBound[" ; |
68 | PrintBoundValue(p->stream, op->min_value); |
69 | p->stream << ','; |
70 | PrintBoundValue(p->stream, op->max_value); |
71 | p->stream << ']'; |
72 | }); |
73 | |
74 | // internal entry for const int bound |
75 | struct ConstIntBoundAnalyzer::Entry { |
76 | int64_t min_value; |
77 | int64_t max_value; |
78 | |
79 | bool is_const(int64_t value) const { return min_value == max_value && min_value == value; } |
80 | |
81 | bool operator==(const Entry& other) const { |
82 | return min_value == other.min_value && max_value == other.max_value; |
83 | } |
84 | }; |
85 | |
86 | class ConstIntBoundAnalyzer::Impl |
87 | : public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> { |
88 | public: |
89 | /*! \brief additional bound info about expr in bound */ |
90 | struct BoundInfo { |
91 | /*! \brief The expr */ |
92 | PrimExpr expr; |
93 | /*! \brief The additional bound */ |
94 | Entry bound; |
95 | |
96 | BoundInfo() {} |
97 | BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {} |
98 | }; |
99 | |
100 | void Bind(const Var& var, const Range& range, bool allow_override) { |
101 | Entry a = VisitExpr(range->min); |
102 | Entry b = VisitExpr(range->extent); |
103 | Entry ret; |
104 | ret.min_value = a.min_value; |
105 | ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); |
106 | Update(var, ret, allow_override); |
107 | } |
108 | |
109 | void Update(const Var& var, const Entry& info, bool allow_override) { |
110 | if (!allow_override) { |
111 | auto it = var_map_.find(var); |
112 | if (it != var_map_.end()) { |
113 | ICHECK(it->second == info) |
114 | << "Trying to update var \'" << var << "\'" |
115 | << " with a different const bound: " |
116 | << "original=" << ConstIntBound(it->second.min_value, it->second.max_value) |
117 | << ", new=" << ConstIntBound(info.min_value, info.max_value); |
118 | } |
119 | } |
120 | var_map_[var] = info; |
121 | } |
122 | |
123 | Entry VisitExpr_(const LetNode* op) final { |
124 | auto it = var_map_.find(op->var); |
125 | // if the var has not been binded, update the info. |
126 | if (it == var_map_.end()) { |
127 | var_map_[op->var] = this->VisitExpr(op->value); |
128 | Entry ret = VisitExpr(op->body); |
129 | var_map_.erase(op->var); |
130 | return ret; |
131 | } else { |
132 | return VisitExpr(op->body); |
133 | } |
134 | } |
135 | |
136 | void Update(const Var& var, const ConstIntBound& info, bool allow_override) { |
137 | Update(var, MakeBound(info->min_value, info->max_value), allow_override); |
138 | } |
139 | |
140 | // Override visitor behaviors |
141 | Entry VisitExprDefault_(const Object* op) final { |
142 | return Everything(static_cast<const PrimExprNode*>(op)->dtype); |
143 | } |
144 | |
145 | Entry VisitExpr(const PrimExpr& expr) final { |
146 | Entry res = ExprFunctor::VisitExpr(expr); |
147 | tir::ExprDeepEqual equal; |
148 | // a linear search over additional info |
149 | // assume we won't have a lot of conditions |
150 | for (const BoundInfo& info : additional_info_) { |
151 | if (equal(expr, info.expr)) { |
152 | res = Intersect(res, info.bound); |
153 | } |
154 | } |
155 | if (bound_) { |
156 | auto val = bound_->find(expr); |
157 | if (val != bound_->end()) { |
158 | auto everything = Everything(expr->dtype); |
159 | ICHECK( |
160 | (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || |
161 | (val->second->min_value == everything.min_value && |
162 | val->second->max_value == everything.max_value)) |
163 | << "Detected bound for " << expr << "conflicts with memorization" ; |
164 | } |
165 | (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value); |
166 | } |
167 | return res; |
168 | } |
169 | |
170 | Entry VisitExpr_(const RampNode* op) final { |
171 | // op = {base + i * stride | 0 <= i < lanes} |
172 | // Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes) |
173 | // Note that `base + i * stride` is linear w.r.t. `i` |
174 | // Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1) |
175 | Entry a = VisitExpr(op->base); |
176 | Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride); |
177 | return Union(a, b); |
178 | } |
179 | |
180 | Entry VisitExpr_(const CastNode* op) final { |
181 | Entry a; |
182 | |
183 | // int(ceil(log2(cast(n,"float64")))) is used as the |
184 | // implementation of topi.math.ceil_log2, and appears in iteration |
185 | // bounds. |
186 | if (auto opt = FindCeilLog2Arg(op)) { |
187 | a = CeilLog2Bounds(opt.value()); |
188 | } else { |
189 | a = VisitExpr(op->value); |
190 | } |
191 | |
192 | Entry b = Everything(op->dtype); |
193 | return Intersect(a, b); |
194 | } |
195 | |
196 | Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); } |
197 | |
198 | Entry VisitExpr_(const AddNode* op) final { |
199 | Entry a = VisitExpr(op->a); |
200 | Entry b = VisitExpr(op->b); |
201 | Entry ret; |
202 | ret.min_value = InfAwareAdd(a.min_value, b.min_value); |
203 | ret.max_value = InfAwareAdd(a.max_value, b.max_value); |
204 | return ret; |
205 | } |
206 | |
207 | Entry VisitExpr_(const SubNode* op) final { |
208 | Entry a = VisitExpr(op->a); |
209 | Entry b = VisitExpr(op->b); |
210 | Entry ret; |
211 | ret.min_value = InfAwareAdd(a.min_value, -b.max_value); |
212 | ret.max_value = InfAwareAdd(a.max_value, -b.min_value); |
213 | return ret; |
214 | } |
215 | |
216 | Entry VisitExpr_(const MulNode* op) final { |
217 | Entry a = VisitExpr(op->a); |
218 | Entry b = VisitExpr(op->b); |
219 | return BinaryOpBoundary(a, b, InfAwareMul); |
220 | } |
221 | |
222 | Entry VisitExpr_(const DivNode* op) final { |
223 | Entry a = VisitExpr(op->a); |
224 | Entry b = VisitExpr(op->b); |
225 | ICHECK(!b.is_const(0)) << "divide by zero" ; |
226 | return HandleDivision(a, b, op->dtype, InfAwareDiv); |
227 | } |
228 | |
229 | Entry VisitExpr_(const ModNode* op) final { |
230 | Entry a = VisitExpr(op->a); |
231 | Entry b = VisitExpr(op->b); |
232 | if (b.min_value > 0) { |
233 | int64_t b_max_cap = InfAwareAdd(b.max_value, -1); |
234 | if (a.min_value >= 0) { |
235 | // 0 <= [a_min, a_max] < b_min |
236 | if (a.max_value < b.min_value) return a; |
237 | // other case, we can get close to 0 |
238 | return MakeBound(0, std::min(a.max_value, b_max_cap)); |
239 | } else { |
240 | return MakeBound(std::max(a.min_value, -b_max_cap), |
241 | std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); |
242 | } |
243 | } else { |
244 | ICHECK(!b.is_const(0)) << "mod by zero" ; |
245 | // mod by negative value is rare, |
246 | // and we just use the simpliest rule. |
247 | return Everything(op->dtype); |
248 | } |
249 | } |
250 | |
251 | Entry VisitExpr_(const FloorDivNode* op) final { |
252 | Entry a = VisitExpr(op->a); |
253 | Entry b = VisitExpr(op->b); |
254 | ICHECK(!b.is_const(0)) << "floordiv by zero" ; |
255 | return HandleDivision(a, b, op->dtype, InfAwareFloorDiv); |
256 | } |
257 | |
258 | Entry VisitExpr_(const FloorModNode* op) final { |
259 | /* let a / b = x + y, where x is integer, y \in [0, 1) |
260 | * floormod(a, b) = a - floordiv(a, b) * b |
261 | * floordiv(a, b) = x |
262 | * floormod(a, b) = a - floordiv(a, b) * b |
263 | * = a - x * b |
264 | * = a - (a / b - y) * b |
265 | * = a - a + y * b |
266 | * = y * b |
267 | * note that 0 <= y < 1 |
268 | * when b > 0, 0 <= b * y < b |
269 | * 0 <= b * y <= b - 1 |
270 | * when b < 0, b < b * y <= 0 |
271 | * b + 1 <= b * y <= 0 |
272 | * In all cases, min(0, b + 1) <= b * y <= max(0, b - 1) |
273 | * min(0, b_min + 1) <= b * y <= max(0, b_max - 1) |
274 | * That is, min(0, b_min + 1) <= floormod(a, b) <= max(0, b_max - 1) |
275 | */ |
276 | Entry a = VisitExpr(op->a); |
277 | Entry b = VisitExpr(op->b); |
278 | if (b.min_value > 0) { |
279 | int64_t b_max_cap = InfAwareAdd(b.max_value, -1); |
280 | if (a.min_value >= 0) { |
281 | // 0 <= [a_min, a_max] < b_min |
282 | if (a.max_value < b.min_value) return a; |
283 | // other case, we can get close to 0 |
284 | return MakeBound(0, std::min(a.max_value, b_max_cap)); |
285 | } else { |
286 | return MakeBound(0, b_max_cap); |
287 | } |
288 | } else { |
289 | ICHECK(!b.is_const(0)) << "floormod by zero" ; |
290 | int64_t b_min_cap = InfAwareAdd(b.min_value, 1); |
291 | int64_t b_max_cap = InfAwareAdd(b.max_value, -1); |
292 | return Intersect(MakeBound(std::min(static_cast<int64_t>(0), b_min_cap), |
293 | std::max(static_cast<int64_t>(0), b_max_cap)), |
294 | Everything(op->dtype)); |
295 | } |
296 | } |
297 | |
298 | Entry VisitExpr_(const MinNode* op) final { |
299 | Entry a = VisitExpr(op->a); |
300 | Entry b = VisitExpr(op->b); |
301 | Entry ret; |
302 | ret.min_value = std::min(a.min_value, b.min_value); |
303 | ret.max_value = std::min(a.max_value, b.max_value); |
304 | return ret; |
305 | } |
306 | |
307 | Entry VisitExpr_(const MaxNode* op) final { |
308 | Entry a = VisitExpr(op->a); |
309 | Entry b = VisitExpr(op->b); |
310 | Entry ret; |
311 | ret.min_value = std::max(a.min_value, b.min_value); |
312 | ret.max_value = std::max(a.max_value, b.max_value); |
313 | return ret; |
314 | } |
315 | |
316 | Entry VisitExpr_(const SelectNode* op) final { |
317 | Entry a = VisitExpr(op->true_value); |
318 | Entry b = VisitExpr(op->false_value); |
319 | return Union(a, b); |
320 | } |
321 | |
322 | Entry VisitExpr_(const CallNode* op) final { |
323 | // only special handle >> and & which can be |
324 | // used for index calculation. |
325 | |
326 | if (op->op.same_as(tir::builtin::shift_right())) { |
327 | return VisitRightShift(op); |
328 | } else if (op->op.same_as(tir::builtin::shift_left())) { |
329 | return VisitLeftShift(op); |
330 | } else if (op->op.same_as(tir::builtin::bitwise_and())) { |
331 | return VisitBitwiseAnd(op); |
332 | } else { |
333 | return Everything(op->dtype); |
334 | } |
335 | } |
336 | |
337 | Entry VisitExpr_(const VarNode* op) final { |
338 | Var v = GetRef<Var>(op); |
339 | auto it = var_map_.find(v); |
340 | if (it != var_map_.end()) { |
341 | return it->second; |
342 | } else { |
343 | return Everything(op->dtype); |
344 | } |
345 | } |
346 | |
347 | Entry VisitExpr_(const SizeVarNode* op) final { |
348 | SizeVar v = GetRef<SizeVar>(op); |
349 | auto it = var_map_.find(v); |
350 | if (it != var_map_.end()) { |
351 | return it->second; |
352 | } else { |
353 | return MakeBound(0, kPosInf); |
354 | } |
355 | } |
356 | |
357 | Entry VisitLeftShift(const CallNode* op) { |
358 | Entry a = VisitExpr(op->args[0]); |
359 | Entry b = VisitExpr(op->args[1]); |
360 | |
361 | if (a.min_value < 0 || b.min_value < 0) { |
362 | // If either operand can negative, we may run into undefined |
363 | // behavior for some targets. In these cases, avoid making any |
364 | // assumptions about the result. |
365 | return Everything(op->dtype); |
366 | } |
367 | |
368 | return BinaryOpBoundary(a, b, InfAwareLeftShift); |
369 | } |
370 | |
371 | Entry VisitRightShift(const CallNode* op) { |
372 | Entry a = VisitExpr(op->args[0]); |
373 | Entry b = VisitExpr(op->args[1]); |
374 | return BinaryOpBoundary(a, b, InfAwareRightShift); |
375 | } |
376 | |
377 | Entry VisitBitwiseAnd(const CallNode* op) { |
378 | Entry a = VisitExpr(op->args[0]); |
379 | Entry b = VisitExpr(op->args[1]); |
380 | // handle positive index case. |
381 | if (a.min_value >= 0 && b.min_value >= 0) { |
382 | return MakeBound(0, std::min(a.max_value, b.max_value)); |
383 | } else { |
384 | if (b.min_value >= 0) { |
385 | return MakeBound(0, b.max_value); |
386 | } |
387 | if (a.min_value >= 0) { |
388 | return MakeBound(0, a.max_value); |
389 | } |
390 | return Everything(op->dtype); |
391 | } |
392 | } |
393 | |
394 | std::function<void()> EnterConstraint(const PrimExpr& constraint) { |
395 | std::vector<BoundInfo> info = DetectBoundInfo(constraint); |
396 | if (info.size() == 0) return nullptr; |
397 | size_t old_size = additional_info_.size(); |
398 | additional_info_.insert(additional_info_.end(), info.begin(), info.end()); |
399 | size_t new_size = old_size + info.size(); |
400 | auto frecover = [old_size, new_size, this]() { |
401 | ICHECK_EQ(additional_info_.size(), new_size); |
402 | additional_info_.resize(old_size); |
403 | }; |
404 | return frecover; |
405 | } |
406 | |
407 | private: |
408 | friend class ConstIntBoundAnalyzer; |
409 | // internal variable map |
410 | std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> var_map_; |
411 | // additional bound info |
412 | std::vector<BoundInfo> additional_info_; |
413 | // look up table for memorization |
414 | BoundMapType* bound_{nullptr}; |
415 | // constants: the limit value means umlimited |
416 | // NOTE: kNegInf/kPosInf are used to represent infinity. |
417 | static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; |
418 | static const constexpr int64_t kPosInf = ConstIntBound::kPosInf; |
419 | static_assert(-kNegInf == kPosInf, "invariant of inf" ); |
420 | // internal helper functions |
421 | /*! |
422 | * \brief Get boundary of binary op who are monotonic wrt to one argument. |
423 | * \param a The entry of the left operand. |
424 | * \param b The entry of the right operand. |
425 | * \param op The operator. |
426 | * \tparam F the operator function type. |
427 | * \return The result. |
428 | */ |
429 | template <typename F> |
430 | static Entry BinaryOpBoundary(Entry a, Entry b, const F& op) { |
431 | Entry ret; |
432 | // The boundary point must be shihft of the original boundary. |
433 | int64_t v1 = op(a.min_value, b.min_value); |
434 | int64_t v2 = op(a.max_value, b.max_value); |
435 | int64_t v3 = op(a.min_value, b.max_value); |
436 | int64_t v4 = op(a.max_value, b.min_value); |
437 | ret.min_value = std::min(std::min(std::min(v1, v2), v3), v4); |
438 | ret.max_value = std::max(std::max(std::max(v1, v2), v3), v4); |
439 | return ret; |
440 | } |
441 | /*! |
442 | * \brief Get value boundaries of division (e.g. Div or FloorDiv). |
443 | * \param a The entry of the left operand. |
444 | * \param b The entry of the right operand. |
445 | * \param dt The data type of the division operator. |
446 | * \param op The division operator. |
447 | * \tparam F the operator function type. |
448 | * \return The result. |
449 | */ |
450 | template <typename F> |
451 | static Entry HandleDivision(Entry a, Entry b, DataType dt, const F& op) { |
452 | // Here we have a / b. |
453 | // The largest value of the division will be for the smallest (with |
454 | // respect to the absolute value) value of b. If the range of b starts |
455 | // at a negative value and ends at a positive one, narrow it down to |
456 | // be closer to 0, because BinaryOpBoundary only checks end-points of |
457 | // the domain ranges. |
458 | |
459 | // If the range of b contains 0, then some infinity will be involved |
460 | if (b.min_value <= 0 && 0 <= b.max_value && dt.is_int()) { |
461 | Entry b_neg = b.min_value < 0 ? MakeBound(b.min_value, -1) : Everything(dt); |
462 | Entry b_pos = b.max_value > 0 ? MakeBound(1, b.max_value) : Everything(dt); |
463 | |
464 | Entry e_neg = BinaryOpBoundary(a, b_neg, op); |
465 | Entry e_pos = BinaryOpBoundary(a, b_pos, op); |
466 | |
467 | return MakeBound(std::min(e_neg.min_value, e_pos.min_value), |
468 | std::max(e_neg.max_value, e_pos.max_value)); |
469 | } else if (b.min_value == 0 && dt.is_uint()) { |
470 | // uints only have one sided bounds |
471 | Entry assumed_b = MakeBound(1, b.max_value); |
472 | return BinaryOpBoundary(a, assumed_b, op); |
473 | } |
474 | // If the range of b does not have 0, use BinaryOpBoundary. |
475 | return BinaryOpBoundary(a, b, op); |
476 | } |
477 | /*! |
478 | * \brief Compute x + y, aware of inf. |
479 | * \param x The left operand. |
480 | * \param y The right operand. |
481 | * \return the result. |
482 | */ |
483 | static int64_t InfAwareAdd(int64_t x, int64_t y) { |
484 | if (x == kPosInf) { |
485 | ICHECK(y != kNegInf); |
486 | return kPosInf; |
487 | } |
488 | if (x == kNegInf) { |
489 | ICHECK(y != kPosInf); |
490 | return kNegInf; |
491 | } |
492 | if (y == kPosInf || y == kNegInf) return y; |
493 | if (WillOverflow<AddNode>(x, y, kNegInf, kPosInf)) { |
494 | if (x > 0) return kPosInf; |
495 | return kNegInf; |
496 | } |
497 | return x + y; |
498 | } |
499 | /*! |
500 | * \brief Compute x * y, aware of inf. |
501 | * \param x The left operand. |
502 | * \param y The right operand. |
503 | * \return the result. |
504 | */ |
505 | static int64_t InfAwareMul(int64_t x, int64_t y) { |
506 | if (!WillOverflow<MulNode>(x, y, kNegInf, kPosInf)) return x * y; |
507 | if ((x > 0 && y > 0) || (x < 0 && y < 0)) return kPosInf; |
508 | return kNegInf; |
509 | } |
510 | /*! |
511 | * \brief Compute x / y, aware of inf. |
512 | * \param x The left operand. |
513 | * \param y The right operand. |
514 | * \return the result. |
515 | */ |
516 | static int64_t InfAwareDiv(int64_t x, int64_t y) { |
517 | ICHECK_NE(y, 0); |
518 | if (x == kPosInf || x == kNegInf) { |
519 | if (y > 0) return x; |
520 | return -x; |
521 | } |
522 | return x / y; |
523 | } |
524 | /*! |
525 | * \brief Compute floodiv(x, y), aware of inf. |
526 | * \param x The left operand. |
527 | * \param y The right operand. |
528 | * \return the result. |
529 | */ |
530 | static int64_t InfAwareFloorDiv(int64_t x, int64_t y) { |
531 | ICHECK_NE(y, 0); |
532 | if (x == kPosInf || x == kNegInf) { |
533 | if (y > 0) return x; |
534 | return -x; |
535 | } |
536 | return floordiv(x, y); |
537 | } |
538 | /*! |
539 | * \brief Compute x << y, aware of inf. |
540 | * \param x The left operand. |
541 | * \param y The right operand. |
542 | * \return the result. |
543 | */ |
544 | static int64_t InfAwareLeftShift(int64_t x, int64_t y) { |
545 | if (x == kPosInf || x == kNegInf) return x; |
546 | |
547 | // Can be replaced with std::bit_width in C++20 |
548 | auto bit_width = [](int64_t as_signed) { |
549 | uint64_t val = std::abs(as_signed); |
550 | int num_bits = 0; |
551 | while (val) { |
552 | ++num_bits; |
553 | val >>= 1; |
554 | } |
555 | return num_bits; |
556 | }; |
557 | int x_bits = bit_width(x); |
558 | if (x_bits + y < 64) { |
559 | return x << y; |
560 | } else { |
561 | return kPosInf; |
562 | } |
563 | } |
564 | /*! |
565 | * \brief Compute x >> y, aware of inf. |
566 | * \param x The left operand. |
567 | * \param y The right operand. |
568 | * \return the result. |
569 | */ |
570 | static int64_t InfAwareRightShift(int64_t x, int64_t y) { |
571 | if (x == kPosInf || x == kNegInf) return x; |
572 | return x >> y; |
573 | } |
574 | /*! |
575 | * \brief Make a new bound entry. |
576 | */ |
577 | static Entry MakeBound(int64_t min_value, int64_t max_value) { |
578 | Entry e; |
579 | e.min_value = (min_value == kPosInf) ? min_value - 1 : min_value; |
580 | e.max_value = (max_value == kNegInf) ? max_value + 1 : max_value; |
581 | return e; |
582 | } |
583 | /*! |
584 | * \brief Create union of two sets. |
585 | * \param a The left operand. |
586 | * \param b the right operand. |
587 | */ |
588 | static Entry Union(Entry a, Entry b) { |
589 | Entry ret; |
590 | ret.min_value = std::min(a.min_value, b.min_value); |
591 | ret.max_value = std::max(a.max_value, b.max_value); |
592 | return ret; |
593 | } |
594 | /*! |
595 | * \brief Create intersect of two sets. |
596 | * \param a The left operand. |
597 | * \param b the right operand. |
598 | */ |
599 | static Entry Intersect(Entry a, Entry b) { |
600 | Entry ret; |
601 | ret.min_value = std::max(a.min_value, b.min_value); |
602 | ret.max_value = std::min(a.max_value, b.max_value); |
603 | return ret; |
604 | } |
605 | /*! |
606 | * \brief return everything dtype can represent. |
607 | * \param dtype The data type. |
608 | * \return Bound that represent everything dtype can represent. |
609 | */ |
610 | static Entry Everything(DataType dtype) { |
611 | if (!dtype.is_int() && !dtype.is_uint()) { |
612 | return MakeBound(kNegInf, kPosInf); |
613 | } |
614 | Entry ret; |
615 | int64_t vbits = dtype.bits() - static_cast<int>(dtype.is_int()); |
616 | if (dtype.is_uint()) { |
617 | ret.min_value = 0; |
618 | } else { |
619 | if (vbits >= 63) { |
620 | ret.min_value = kNegInf; |
621 | } else { |
622 | ret.min_value = -(static_cast<int64_t>(1) << vbits); |
623 | } |
624 | } |
625 | if (vbits >= 63) { |
626 | ret.max_value = kPosInf; |
627 | } else { |
628 | ret.max_value = (static_cast<int64_t>(1) << vbits) - 1; |
629 | } |
630 | return ret; |
631 | } |
632 | |
633 | /*! |
634 | * \brief Detect additional constant bound from cond, if any |
635 | * \param cond The constraint condition. |
636 | * \return List of detected bounds. |
637 | */ |
638 | static std::vector<BoundInfo> DetectBoundInfo(const PrimExpr& cond) { |
639 | PVar<PrimExpr> x, y; |
640 | PVar<IntImm> c; |
641 | |
642 | std::vector<BoundInfo> info; |
643 | auto add_info = [&](const PrimExpr& expr, int64_t min_value, int64_t max_value) { |
644 | // If the conditional is comparing two integers, do not assign a |
645 | // value to them. |
646 | if (!expr->IsInstance<IntImmNode>()) { |
647 | info.push_back(BoundInfo(expr, MakeBound(min_value, max_value))); |
648 | } |
649 | }; |
650 | |
651 | for (const auto& subexpr : ExtractConstraints(cond)) { |
652 | // NOTE: The canonical form always uses <= or <, but a |
653 | // user-supplied constraint from the python API might not be |
654 | // canonicalized. |
655 | if ((c <= x).Match(subexpr) || (x >= c).Match(subexpr)) { |
656 | add_info(x.Eval(), c.Eval()->value, kPosInf); |
657 | } else if ((c < x).Match(subexpr) || (x > c).Match(subexpr)) { |
658 | add_info(x.Eval(), c.Eval()->value + 1, kPosInf); |
659 | } else if ((x <= c).Match(subexpr) || (x >= c).Match(subexpr)) { |
660 | add_info(x.Eval(), kNegInf, c.Eval()->value); |
661 | } else if ((x < c).Match(subexpr) || (c > x).Match(subexpr)) { |
662 | add_info(x.Eval(), kNegInf, c.Eval()->value - 1); |
663 | } else if ((x == c).Match(subexpr) || (c == x).Match(subexpr)) { |
664 | add_info(x.Eval(), c.Eval()->value, c.Eval()->value); |
665 | } |
666 | } |
667 | |
668 | return info; |
669 | } |
670 | |
671 | /*! |
672 | * \brief Extract the argument from int(ceil(log2(arg))) |
673 | * |
674 | * This expression is used as the implementation of |
675 | * topi.math.ceil_log2, and can appear in iteration bounds. |
676 | */ |
677 | static Optional<PrimExpr> FindCeilLog2Arg(const CastNode* op) { |
678 | if (op->dtype.is_int()) { |
679 | if (auto as_call = op->value.as<CallNode>()) { |
680 | if (as_call->op.same_as(Op::Get("tir.ceil" ))) { |
681 | PrimExpr ceil_arg = as_call->args[0]; |
682 | if (auto arg_call = ceil_arg.as<CallNode>()) { |
683 | if (arg_call->op.same_as(Op::Get("tir.log2" ))) { |
684 | PrimExpr log_arg = arg_call->args[0]; |
685 | return log_arg; |
686 | } |
687 | } |
688 | } |
689 | } |
690 | } |
691 | return NullOpt; |
692 | } |
693 | |
694 | /*! \brief Propagate constraints through ceil(log2(arg)) |
695 | * |
696 | * Helper function for CastNode visitor |
697 | */ |
698 | Entry CeilLog2Bounds(PrimExpr arg) { |
699 | if (auto as_float = arg.as<FloatImmNode>()) { |
700 | // A cast from int to float may have already been simplified |
701 | // out. Normally we don't inspect floating-point arguments, but here we can |
702 | int64_t val = std::ceil(std::log2(as_float->value)); |
703 | return MakeBound(val, val); |
704 | } else { |
705 | Entry arg_bounds = VisitExpr(arg); |
706 | return MakeBound(std::ceil(std::log2(arg_bounds.min_value)), |
707 | std::ceil(std::log2(arg_bounds.max_value))); |
708 | } |
709 | } |
710 | }; |
711 | |
712 | ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const { |
713 | Entry ret = impl_->VisitExpr(expr); |
714 | return ConstIntBound(ret.min_value, ret.max_value); |
715 | } |
716 | |
717 | ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapType* bound) { |
718 | impl_->bound_ = bound; |
719 | Entry ret = impl_->VisitExpr(expr); |
720 | impl_->bound_ = nullptr; |
721 | return ConstIntBound(ret.min_value, ret.max_value); |
722 | } |
723 | |
724 | void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool allow_override) { |
725 | impl_->Update(var, info, allow_override); |
726 | } |
727 | |
728 | void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { |
729 | impl_->Bind(var, range, allow_override); |
730 | } |
731 | |
732 | std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { |
733 | return impl_->EnterConstraint(constraint); |
734 | } |
735 | |
736 | ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {} |
737 | |
738 | ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } |
739 | |
740 | } // namespace arith |
741 | } // namespace tvm |
742 | |