1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file modular_set.cc |
22 | * \brief Modular set analysis |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/builtin.h> |
27 | #include <tvm/tir/expr_functor.h> |
28 | #include <tvm/tir/op.h> |
29 | |
30 | #include <limits> |
31 | #include <unordered_map> |
32 | #include <utility> |
33 | |
34 | #include "pattern_match.h" |
35 | |
36 | namespace tvm { |
37 | namespace arith { |
38 | |
39 | using namespace tir; |
40 | |
41 | TVM_REGISTER_NODE_TYPE(ModularSetNode); |
42 | |
43 | ModularSet::ModularSet(int64_t coeff, int64_t base) { |
44 | auto node = make_object<ModularSetNode>(); |
45 | node->coeff = coeff; |
46 | node->base = base; |
47 | // finish construction. |
48 | data_ = std::move(node); |
49 | } |
50 | |
51 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
52 | .set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) { |
53 | auto* op = static_cast<const ModularSetNode*>(node.get()); |
54 | p->stream << "ModularSet(" |
55 | << "coeff=" << op->coeff << ", base=" << op->base << ')'; |
56 | }); |
57 | |
58 | ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } |
59 | |
60 | TVM_REGISTER_GLOBAL("arith.ModularSet" ).set_body_typed(MakeModularSet); |
61 | |
62 | // internal entry for const int bound |
63 | struct ModularSetAnalyzer::Entry { |
64 | int64_t coeff{1}; |
65 | int64_t base{0}; |
66 | |
67 | Entry() = default; |
68 | |
69 | Entry(int64_t coeff, int64_t base) { |
70 | if (coeff < 0) { |
71 | // `analyzer->canonical_simplify()` can generate expressions with |
72 | // negative coefficients (e.g. simplifying `floormod(-i, 2)` |
73 | // into `floormod(i, -2) * -1`). When this happens, the |
74 | // ModularSet may enter a constraint based on this expression. |
75 | // |
76 | // Handling a negative coeff uses the same sign convention as |
77 | // canonical_simplify, requiring that |
78 | // `floormod(var, coeff) == -floormod(var, -coeff)`. |
79 | coeff *= -1; |
80 | base *= -1; |
81 | } |
82 | this->coeff = coeff; |
83 | if (coeff != 0) { |
84 | base = base % coeff; |
85 | if (base < 0) base += coeff; |
86 | } |
87 | this->base = base; |
88 | } |
89 | |
90 | bool is_const() const { return coeff == 0; } |
91 | |
92 | bool operator==(const Entry& other) const { return coeff == other.coeff && base == other.base; } |
93 | |
94 | bool operator==(const ModularSet& other) const { |
95 | return other.defined() && coeff == other->coeff && base == other->base; |
96 | } |
97 | }; |
98 | |
99 | class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> { |
100 | public: |
101 | explicit Impl(Analyzer* parent) : parent_(parent) {} |
102 | |
103 | void Update(const Var& var, const ModularSet& info, bool allow_override) { |
104 | if (!allow_override) { |
105 | auto it = var_map_.find(var); |
106 | if (it != var_map_.end()) { |
107 | ICHECK(it->second == info) |
108 | << "Trying to update var \'" << var << "\'" |
109 | << " with a different const bound: " |
110 | << "original=" << ModularSet(it->second.coeff, it->second.base) << ", new=" << info; |
111 | } |
112 | } |
113 | var_map_[var] = Entry(info->coeff, info->base); |
114 | } |
115 | |
116 | // Detect useful constraints and use them in the analysis scope. |
117 | std::function<void()> EnterConstraint(const PrimExpr& constraint) { |
118 | PVar<Var> var; |
119 | PVar<IntImm> coeff, base; |
120 | // pattern match interesting constraints |
121 | if ((truncmod(var, coeff) == base).Match(constraint) || |
122 | (floormod(var, coeff) == base).Match(constraint)) { |
123 | Entry entry(coeff.Eval()->value, base.Eval()->value); |
124 | return UpdateByIntersect(var.Eval(), entry); |
125 | } |
126 | if ((var == base).Match(constraint) || (base == var).Match(constraint)) { |
127 | Entry entry(1, base.Eval()->value); |
128 | return UpdateByIntersect(var.Eval(), entry); |
129 | } |
130 | return nullptr; |
131 | } |
132 | |
133 | // Override visitor behaviors |
134 | Entry VisitExprDefault_(const Object* op) final { return Everything(); } |
135 | |
136 | Entry VisitExpr_(const LetNode* op) final { |
137 | auto it = var_map_.find(op->var); |
138 | // if the var has not been binded, update the info. |
139 | if (it == var_map_.end()) { |
140 | var_map_[op->var] = this->VisitExpr(op->value); |
141 | Entry ret = VisitExpr(op->body); |
142 | var_map_.erase(op->var); |
143 | return ret; |
144 | } else { |
145 | return VisitExpr(op->body); |
146 | } |
147 | } |
148 | |
149 | Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } |
150 | |
151 | Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); } |
152 | |
153 | Entry VisitExpr_(const AddNode* op) final { |
154 | Entry a = VisitExpr(op->a); |
155 | Entry b = VisitExpr(op->b); |
156 | int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); |
157 | return Entry(coeff, a.base + b.base); |
158 | } |
159 | |
160 | Entry VisitExpr_(const SubNode* op) final { |
161 | Entry a = VisitExpr(op->a); |
162 | Entry b = VisitExpr(op->b); |
163 | int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); |
164 | return Entry(coeff, a.base - b.base); |
165 | } |
166 | |
167 | Entry VisitExpr_(const MulNode* op) final { |
168 | Entry a = VisitExpr(op->a); |
169 | Entry b = VisitExpr(op->b); |
170 | // Simplification rule, x, y, z are in Z |
171 | // (p x + n) (q y + m) |
172 | // -> pq xy + pm x + qn y + mn |
173 | // -> pq z + pm x + qn y + mn |
174 | int64_t pq = a.coeff * b.coeff; |
175 | int64_t pm = a.coeff * b.base; |
176 | int64_t qn = a.base * b.coeff; |
177 | int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); |
178 | return Entry(coeff, a.base * b.base); |
179 | } |
180 | |
181 | Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) { |
182 | Entry a = VisitExpr(lhs); |
183 | ICHECK_NE(val, 0); |
184 | if (a.coeff % val == 0) { |
185 | if (a.base == 0) { |
186 | // a c x / c -> a x |
187 | return Entry(std::abs(a.coeff / val), 0); |
188 | } |
189 | // positive division have a clear rounding mode. |
190 | // Only handle case where we clearly know we need to round down. |
191 | if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { |
192 | return Entry(a.coeff / val, a.base / val); |
193 | } |
194 | } |
195 | return Everything(); |
196 | } |
197 | |
198 | Entry VisitExpr_(const DivNode* op) final { |
199 | Entry b = VisitExpr(op->b); |
200 | if (b.is_const()) { |
201 | return DivByConst(op->a, b.base, false); |
202 | } |
203 | return Everything(); |
204 | } |
205 | |
206 | Entry VisitExpr_(const FloorDivNode* op) final { |
207 | Entry b = VisitExpr(op->b); |
208 | if (b.is_const()) { |
209 | return DivByConst(op->a, b.base, true); |
210 | } |
211 | return Everything(); |
212 | } |
213 | |
214 | Entry VisitExpr_(const MinNode* op) final { |
215 | Entry a = VisitExpr(op->a); |
216 | Entry b = VisitExpr(op->b); |
217 | return Union(a, b); |
218 | } |
219 | |
220 | Entry VisitExpr_(const MaxNode* op) final { |
221 | Entry a = VisitExpr(op->a); |
222 | Entry b = VisitExpr(op->b); |
223 | return Union(a, b); |
224 | } |
225 | |
226 | Entry VisitExpr_(const SelectNode* op) final { |
227 | Entry a = VisitExpr(op->true_value); |
228 | Entry b = VisitExpr(op->false_value); |
229 | return Union(a, b); |
230 | } |
231 | |
232 | Entry ModByConst(const PrimExpr& lhs, int64_t val, bool round_down) { |
233 | Entry a = VisitExpr(lhs); |
234 | ICHECK_NE(val, 0); |
235 | int64_t coeff = ZeroAwareGCD(a.coeff, val); |
236 | if (a.base % coeff == 0 || |
237 | (a.base > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0)))) { |
238 | return Entry(coeff, a.base % coeff); |
239 | } |
240 | return Everything(); |
241 | } |
242 | |
243 | Entry VisitExpr_(const FloorModNode* op) final { |
244 | Entry b = VisitExpr(op->b); |
245 | if (b.is_const()) { |
246 | return ModByConst(op->a, b.base, true); |
247 | } |
248 | return Everything(); |
249 | } |
250 | |
251 | Entry VisitExpr_(const ModNode* op) final { |
252 | Entry b = VisitExpr(op->b); |
253 | if (b.is_const()) { |
254 | return ModByConst(op->a, b.base, false); |
255 | } |
256 | return Everything(); |
257 | } |
258 | |
259 | Entry VisitExpr_(const CallNode* op) final { |
260 | // only special handle >> which can be |
261 | // used for index calculation. |
262 | if (op->op.same_as(tir::builtin::shift_right())) { |
263 | return VisitRightShift(op); |
264 | } else if (op->op.same_as(tir::builtin::bitwise_and())) { |
265 | return VisitBitwiseAnd(op); |
266 | } else { |
267 | return Everything(); |
268 | } |
269 | } |
270 | |
271 | Entry VisitExpr_(const VarNode* op) final { |
272 | Var v = GetRef<Var>(op); |
273 | auto it = var_map_.find(v); |
274 | if (it != var_map_.end()) { |
275 | return it->second; |
276 | } else { |
277 | return Everything(); |
278 | } |
279 | } |
280 | |
281 | Entry VisitRightShift(const CallNode* op) { |
282 | Entry b = VisitExpr(op->args[1]); |
283 | // a c x / c -> a x |
284 | if (b.is_const()) { |
285 | return DivByConst(op->args[0], static_cast<int64_t>(1) << b.base, true); |
286 | } |
287 | return Everything(); |
288 | } |
289 | |
290 | Entry VisitBitwiseAnd(const CallNode* op) { |
291 | Entry b = VisitExpr(op->args[1]); |
292 | if (b.is_const()) { |
293 | int shift; |
294 | if (is_const_power_of_two_integer(Integer(b.base + 1), &shift)) { |
295 | return ModByConst(op->args[0], static_cast<int64_t>(1) << shift, true); |
296 | } |
297 | } |
298 | return Everything(); |
299 | } |
300 | |
301 | private: |
302 | /*! \brief pointer to parent. */ |
303 | Analyzer* parent_{nullptr}; |
304 | // internal variable map |
305 | std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> var_map_; |
306 | /*! |
307 | * \brief Update var by intersecting entry with var's current set. |
308 | * \param var The variable. |
309 | * \param entry The entry to be updated. |
310 | * \return The recovery function of the scope. |
311 | */ |
312 | std::function<void()> UpdateByIntersect(const Var& var, Entry entry) { |
313 | Entry old = Everything(); |
314 | auto it = var_map_.find(var); |
315 | if (it != var_map_.end()) { |
316 | old = it->second; |
317 | } |
318 | var_map_[var] = Intersect(old, entry); |
319 | // reover function. |
320 | return [this, old, var]() { var_map_[var] = old; }; |
321 | } |
322 | /*! |
323 | * \brief Create union of two sets. |
324 | * \param a The left operand. |
325 | * \param b the right operand. |
326 | */ |
327 | static Entry Union(Entry a, Entry b) { |
328 | // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}} |
329 | int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); |
330 | if (coeff == 0) { |
331 | if (a.base == b.base) return a; |
332 | return Everything(); |
333 | } |
334 | int64_t base0 = a.base % coeff; |
335 | int64_t base1 = b.base % coeff; |
336 | if (base0 == base1) { |
337 | return Entry(coeff, base0); |
338 | } else { |
339 | return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0); |
340 | } |
341 | } |
342 | |
343 | /*! |
344 | * \brief Create interect of two sets. |
345 | * \param a The left operand. |
346 | * \param b the right operand. |
347 | */ |
348 | static Entry Intersect(Entry a, Entry b) { |
349 | int64_t x, y; |
350 | int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base; |
351 | // z = c1 * p + b1 |
352 | // z = c2 * q + b2 |
353 | // c1 * x + c2 * y = gcd(c1, c2) |
354 | // -> c1 * p - c2 * q = b2 - b1 |
355 | // -> p = (b2 - b1) / gcd * x |
356 | // -> q = (b2 - b1) / gcd * (-y) |
357 | // -> z = LCM(x, y) * k + (c1 * p + b1) |
358 | int64_t gcd = ExtendedEuclidean(c1, c2, &x, &y); |
359 | int64_t v = b2 - b1; |
360 | if (v % gcd == 0) { |
361 | x = v / gcd * x; |
362 | y = v / gcd * (-y); |
363 | int64_t coeff = c1 / gcd * c2; |
364 | return Entry(coeff, x * c1 + b1); |
365 | } else { |
366 | return Nothing(); |
367 | } |
368 | } |
369 | /*! |
370 | * \brief return everything dtype can represent. |
371 | * \return Bound that represent everything dtype can represent. |
372 | */ |
373 | static Entry Everything() { return Entry(1, 0); } |
374 | /*! |
375 | * \brief return an empty set |
376 | * \return Bound that represent everything dtype can represent. |
377 | */ |
378 | static Entry Nothing() { return Entry(0, 1); } |
379 | }; |
380 | |
381 | ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { |
382 | Entry ret = impl_->VisitExpr(expr); |
383 | return ModularSet(ret.coeff, ret.base); |
384 | } |
385 | |
386 | void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool allow_override) { |
387 | impl_->Update(var, info, allow_override); |
388 | } |
389 | |
390 | std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) { |
391 | return impl_->EnterConstraint(constraint); |
392 | } |
393 | |
394 | ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} |
395 | |
396 | ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } |
397 | |
398 | } // namespace arith |
399 | } // namespace tvm |
400 | |