1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file modular_set.cc
22 * \brief Modular set analysis
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/builtin.h>
27#include <tvm/tir/expr_functor.h>
28#include <tvm/tir/op.h>
29
30#include <limits>
31#include <unordered_map>
32#include <utility>
33
34#include "pattern_match.h"
35
36namespace tvm {
37namespace arith {
38
39using namespace tir;
40
41TVM_REGISTER_NODE_TYPE(ModularSetNode);
42
43ModularSet::ModularSet(int64_t coeff, int64_t base) {
44 auto node = make_object<ModularSetNode>();
45 node->coeff = coeff;
46 node->base = base;
47 // finish construction.
48 data_ = std::move(node);
49}
50
51TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
52 .set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) {
53 auto* op = static_cast<const ModularSetNode*>(node.get());
54 p->stream << "ModularSet("
55 << "coeff=" << op->coeff << ", base=" << op->base << ')';
56 });
57
58ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); }
59
60TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet);
61
62// internal entry for const int bound
63struct ModularSetAnalyzer::Entry {
64 int64_t coeff{1};
65 int64_t base{0};
66
67 Entry() = default;
68
69 Entry(int64_t coeff, int64_t base) {
70 if (coeff < 0) {
71 // `analyzer->canonical_simplify()` can generate expressions with
72 // negative coefficients (e.g. simplifying `floormod(-i, 2)`
73 // into `floormod(i, -2) * -1`). When this happens, the
74 // ModularSet may enter a constraint based on this expression.
75 //
76 // Handling a negative coeff uses the same sign convention as
77 // canonical_simplify, requiring that
78 // `floormod(var, coeff) == -floormod(var, -coeff)`.
79 coeff *= -1;
80 base *= -1;
81 }
82 this->coeff = coeff;
83 if (coeff != 0) {
84 base = base % coeff;
85 if (base < 0) base += coeff;
86 }
87 this->base = base;
88 }
89
90 bool is_const() const { return coeff == 0; }
91
92 bool operator==(const Entry& other) const { return coeff == other.coeff && base == other.base; }
93
94 bool operator==(const ModularSet& other) const {
95 return other.defined() && coeff == other->coeff && base == other->base;
96 }
97};
98
99class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> {
100 public:
101 explicit Impl(Analyzer* parent) : parent_(parent) {}
102
103 void Update(const Var& var, const ModularSet& info, bool allow_override) {
104 if (!allow_override) {
105 auto it = var_map_.find(var);
106 if (it != var_map_.end()) {
107 ICHECK(it->second == info)
108 << "Trying to update var \'" << var << "\'"
109 << " with a different const bound: "
110 << "original=" << ModularSet(it->second.coeff, it->second.base) << ", new=" << info;
111 }
112 }
113 var_map_[var] = Entry(info->coeff, info->base);
114 }
115
116 // Detect useful constraints and use them in the analysis scope.
117 std::function<void()> EnterConstraint(const PrimExpr& constraint) {
118 PVar<Var> var;
119 PVar<IntImm> coeff, base;
120 // pattern match interesting constraints
121 if ((truncmod(var, coeff) == base).Match(constraint) ||
122 (floormod(var, coeff) == base).Match(constraint)) {
123 Entry entry(coeff.Eval()->value, base.Eval()->value);
124 return UpdateByIntersect(var.Eval(), entry);
125 }
126 if ((var == base).Match(constraint) || (base == var).Match(constraint)) {
127 Entry entry(1, base.Eval()->value);
128 return UpdateByIntersect(var.Eval(), entry);
129 }
130 return nullptr;
131 }
132
133 // Override visitor behaviors
134 Entry VisitExprDefault_(const Object* op) final { return Everything(); }
135
136 Entry VisitExpr_(const LetNode* op) final {
137 auto it = var_map_.find(op->var);
138 // if the var has not been binded, update the info.
139 if (it == var_map_.end()) {
140 var_map_[op->var] = this->VisitExpr(op->value);
141 Entry ret = VisitExpr(op->body);
142 var_map_.erase(op->var);
143 return ret;
144 } else {
145 return VisitExpr(op->body);
146 }
147 }
148
149 Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
150
151 Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); }
152
153 Entry VisitExpr_(const AddNode* op) final {
154 Entry a = VisitExpr(op->a);
155 Entry b = VisitExpr(op->b);
156 int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
157 return Entry(coeff, a.base + b.base);
158 }
159
160 Entry VisitExpr_(const SubNode* op) final {
161 Entry a = VisitExpr(op->a);
162 Entry b = VisitExpr(op->b);
163 int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
164 return Entry(coeff, a.base - b.base);
165 }
166
167 Entry VisitExpr_(const MulNode* op) final {
168 Entry a = VisitExpr(op->a);
169 Entry b = VisitExpr(op->b);
170 // Simplification rule, x, y, z are in Z
171 // (p x + n) (q y + m)
172 // -> pq xy + pm x + qn y + mn
173 // -> pq z + pm x + qn y + mn
174 int64_t pq = a.coeff * b.coeff;
175 int64_t pm = a.coeff * b.base;
176 int64_t qn = a.base * b.coeff;
177 int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
178 return Entry(coeff, a.base * b.base);
179 }
180
181 Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) {
182 Entry a = VisitExpr(lhs);
183 ICHECK_NE(val, 0);
184 if (a.coeff % val == 0) {
185 if (a.base == 0) {
186 // a c x / c -> a x
187 return Entry(std::abs(a.coeff / val), 0);
188 }
189 // positive division have a clear rounding mode.
190 // Only handle case where we clearly know we need to round down.
191 if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
192 return Entry(a.coeff / val, a.base / val);
193 }
194 }
195 return Everything();
196 }
197
198 Entry VisitExpr_(const DivNode* op) final {
199 Entry b = VisitExpr(op->b);
200 if (b.is_const()) {
201 return DivByConst(op->a, b.base, false);
202 }
203 return Everything();
204 }
205
206 Entry VisitExpr_(const FloorDivNode* op) final {
207 Entry b = VisitExpr(op->b);
208 if (b.is_const()) {
209 return DivByConst(op->a, b.base, true);
210 }
211 return Everything();
212 }
213
214 Entry VisitExpr_(const MinNode* op) final {
215 Entry a = VisitExpr(op->a);
216 Entry b = VisitExpr(op->b);
217 return Union(a, b);
218 }
219
220 Entry VisitExpr_(const MaxNode* op) final {
221 Entry a = VisitExpr(op->a);
222 Entry b = VisitExpr(op->b);
223 return Union(a, b);
224 }
225
226 Entry VisitExpr_(const SelectNode* op) final {
227 Entry a = VisitExpr(op->true_value);
228 Entry b = VisitExpr(op->false_value);
229 return Union(a, b);
230 }
231
232 Entry ModByConst(const PrimExpr& lhs, int64_t val, bool round_down) {
233 Entry a = VisitExpr(lhs);
234 ICHECK_NE(val, 0);
235 int64_t coeff = ZeroAwareGCD(a.coeff, val);
236 if (a.base % coeff == 0 ||
237 (a.base > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0)))) {
238 return Entry(coeff, a.base % coeff);
239 }
240 return Everything();
241 }
242
243 Entry VisitExpr_(const FloorModNode* op) final {
244 Entry b = VisitExpr(op->b);
245 if (b.is_const()) {
246 return ModByConst(op->a, b.base, true);
247 }
248 return Everything();
249 }
250
251 Entry VisitExpr_(const ModNode* op) final {
252 Entry b = VisitExpr(op->b);
253 if (b.is_const()) {
254 return ModByConst(op->a, b.base, false);
255 }
256 return Everything();
257 }
258
259 Entry VisitExpr_(const CallNode* op) final {
260 // only special handle >> which can be
261 // used for index calculation.
262 if (op->op.same_as(tir::builtin::shift_right())) {
263 return VisitRightShift(op);
264 } else if (op->op.same_as(tir::builtin::bitwise_and())) {
265 return VisitBitwiseAnd(op);
266 } else {
267 return Everything();
268 }
269 }
270
271 Entry VisitExpr_(const VarNode* op) final {
272 Var v = GetRef<Var>(op);
273 auto it = var_map_.find(v);
274 if (it != var_map_.end()) {
275 return it->second;
276 } else {
277 return Everything();
278 }
279 }
280
281 Entry VisitRightShift(const CallNode* op) {
282 Entry b = VisitExpr(op->args[1]);
283 // a c x / c -> a x
284 if (b.is_const()) {
285 return DivByConst(op->args[0], static_cast<int64_t>(1) << b.base, true);
286 }
287 return Everything();
288 }
289
290 Entry VisitBitwiseAnd(const CallNode* op) {
291 Entry b = VisitExpr(op->args[1]);
292 if (b.is_const()) {
293 int shift;
294 if (is_const_power_of_two_integer(Integer(b.base + 1), &shift)) {
295 return ModByConst(op->args[0], static_cast<int64_t>(1) << shift, true);
296 }
297 }
298 return Everything();
299 }
300
301 private:
302 /*! \brief pointer to parent. */
303 Analyzer* parent_{nullptr};
304 // internal variable map
305 std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> var_map_;
306 /*!
307 * \brief Update var by intersecting entry with var's current set.
308 * \param var The variable.
309 * \param entry The entry to be updated.
310 * \return The recovery function of the scope.
311 */
312 std::function<void()> UpdateByIntersect(const Var& var, Entry entry) {
313 Entry old = Everything();
314 auto it = var_map_.find(var);
315 if (it != var_map_.end()) {
316 old = it->second;
317 }
318 var_map_[var] = Intersect(old, entry);
319 // reover function.
320 return [this, old, var]() { var_map_[var] = old; };
321 }
322 /*!
323 * \brief Create union of two sets.
324 * \param a The left operand.
325 * \param b the right operand.
326 */
327 static Entry Union(Entry a, Entry b) {
328 // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}}
329 int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
330 if (coeff == 0) {
331 if (a.base == b.base) return a;
332 return Everything();
333 }
334 int64_t base0 = a.base % coeff;
335 int64_t base1 = b.base % coeff;
336 if (base0 == base1) {
337 return Entry(coeff, base0);
338 } else {
339 return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
340 }
341 }
342
343 /*!
344 * \brief Create interect of two sets.
345 * \param a The left operand.
346 * \param b the right operand.
347 */
348 static Entry Intersect(Entry a, Entry b) {
349 int64_t x, y;
350 int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base;
351 // z = c1 * p + b1
352 // z = c2 * q + b2
353 // c1 * x + c2 * y = gcd(c1, c2)
354 // -> c1 * p - c2 * q = b2 - b1
355 // -> p = (b2 - b1) / gcd * x
356 // -> q = (b2 - b1) / gcd * (-y)
357 // -> z = LCM(x, y) * k + (c1 * p + b1)
358 int64_t gcd = ExtendedEuclidean(c1, c2, &x, &y);
359 int64_t v = b2 - b1;
360 if (v % gcd == 0) {
361 x = v / gcd * x;
362 y = v / gcd * (-y);
363 int64_t coeff = c1 / gcd * c2;
364 return Entry(coeff, x * c1 + b1);
365 } else {
366 return Nothing();
367 }
368 }
369 /*!
370 * \brief return everything dtype can represent.
371 * \return Bound that represent everything dtype can represent.
372 */
373 static Entry Everything() { return Entry(1, 0); }
374 /*!
375 * \brief return an empty set
376 * \return Bound that represent everything dtype can represent.
377 */
378 static Entry Nothing() { return Entry(0, 1); }
379};
380
381ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) {
382 Entry ret = impl_->VisitExpr(expr);
383 return ModularSet(ret.coeff, ret.base);
384}
385
386void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool allow_override) {
387 impl_->Update(var, info, allow_override);
388}
389
390std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
391 return impl_->EnterConstraint(constraint);
392}
393
394ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
395
396ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; }
397
398} // namespace arith
399} // namespace tvm
400