1 | /* |
---|---|

2 | * Licensed to the Apache Software Foundation (ASF) under one |

3 | * or more contributor license agreements. See the NOTICE file |

4 | * distributed with this work for additional information |

5 | * regarding copyright ownership. The ASF licenses this file |

6 | * to you under the Apache License, Version 2.0 (the |

7 | * "License"); you may not use this file except in compliance |

8 | * with the License. You may obtain a copy of the License at |

9 | * |

10 | * http://www.apache.org/licenses/LICENSE-2.0 |

11 | * |

12 | * Unless required by applicable law or agreed to in writing, |

13 | * software distributed under the License is distributed on an |

14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |

15 | * KIND, either express or implied. See the License for the |

16 | * specific language governing permissions and limitations |

17 | * under the License. |

18 | */ |

19 | |

20 | /*! |

21 | * \file modular_set.cc |

22 | * \brief Modular set analysis |

23 | */ |

24 | #include <tvm/arith/analyzer.h> |

25 | #include <tvm/runtime/registry.h> |

26 | #include <tvm/tir/builtin.h> |

27 | #include <tvm/tir/expr_functor.h> |

28 | #include <tvm/tir/op.h> |

29 | |

30 | #include <limits> |

31 | #include <unordered_map> |

32 | #include <utility> |

33 | |

34 | #include "pattern_match.h" |

35 | |

36 | namespace tvm { |

37 | namespace arith { |

38 | |

39 | using namespace tir; |

40 | |

41 | TVM_REGISTER_NODE_TYPE(ModularSetNode); |

42 | |

43 | ModularSet::ModularSet(int64_t coeff, int64_t base) { |

44 | auto node = make_object<ModularSetNode>(); |

45 | node->coeff = coeff; |

46 | node->base = base; |

47 | // finish construction. |

48 | data_ = std::move(node); |

49 | } |

50 | |

51 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |

52 | .set_dispatch<ModularSetNode>([](const ObjectRef& node, ReprPrinter* p) { |

53 | auto* op = static_cast<const ModularSetNode*>(node.get()); |

54 | p->stream << "ModularSet(" |

55 | << "coeff="<< op->coeff << ", base="<< op->base << ')'; |

56 | }); |

57 | |

58 | ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } |

59 | |

60 | TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); |

61 | |

62 | // internal entry for const int bound |

63 | struct ModularSetAnalyzer::Entry { |

64 | int64_t coeff{1}; |

65 | int64_t base{0}; |

66 | |

67 | Entry() = default; |

68 | |

69 | Entry(int64_t coeff, int64_t base) { |

70 | if (coeff < 0) { |

71 | // `analyzer->canonical_simplify()` can generate expressions with |

72 | // negative coefficients (e.g. simplifying `floormod(-i, 2)` |

73 | // into `floormod(i, -2) * -1`). When this happens, the |

74 | // ModularSet may enter a constraint based on this expression. |

75 | // |

76 | // Handling a negative coeff uses the same sign convention as |

77 | // canonical_simplify, requiring that |

78 | // `floormod(var, coeff) == -floormod(var, -coeff)`. |

79 | coeff *= -1; |

80 | base *= -1; |

81 | } |

82 | this->coeff = coeff; |

83 | if (coeff != 0) { |

84 | base = base % coeff; |

85 | if (base < 0) base += coeff; |

86 | } |

87 | this->base = base; |

88 | } |

89 | |

90 | bool is_const() const { return coeff == 0; } |

91 | |

92 | bool operator==(const Entry& other) const { return coeff == other.coeff && base == other.base; } |

93 | |

94 | bool operator==(const ModularSet& other) const { |

95 | return other.defined() && coeff == other->coeff && base == other->base; |

96 | } |

97 | }; |

98 | |

99 | class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(const PrimExpr&)> { |

100 | public: |

101 | explicit Impl(Analyzer* parent) : parent_(parent) {} |

102 | |

103 | void Update(const Var& var, const ModularSet& info, bool allow_override) { |

104 | if (!allow_override) { |

105 | auto it = var_map_.find(var); |

106 | if (it != var_map_.end()) { |

107 | ICHECK(it->second == info) |

108 | << "Trying to update var \'"<< var << "\'" |

109 | << " with a different const bound: " |

110 | << "original="<< ModularSet(it->second.coeff, it->second.base) << ", new="<< info; |

111 | } |

112 | } |

113 | var_map_[var] = Entry(info->coeff, info->base); |

114 | } |

115 | |

116 | // Detect useful constraints and use them in the analysis scope. |

117 | std::function<void()> EnterConstraint(const PrimExpr& constraint) { |

118 | PVar<Var> var; |

119 | PVar<IntImm> coeff, base; |

120 | // pattern match interesting constraints |

121 | if ((truncmod(var, coeff) == base).Match(constraint) || |

122 | (floormod(var, coeff) == base).Match(constraint)) { |

123 | Entry entry(coeff.Eval()->value, base.Eval()->value); |

124 | return UpdateByIntersect(var.Eval(), entry); |

125 | } |

126 | if ((var == base).Match(constraint) || (base == var).Match(constraint)) { |

127 | Entry entry(1, base.Eval()->value); |

128 | return UpdateByIntersect(var.Eval(), entry); |

129 | } |

130 | return nullptr; |

131 | } |

132 | |

133 | // Override visitor behaviors |

134 | Entry VisitExprDefault_(const Object* op) final { return Everything(); } |

135 | |

136 | Entry VisitExpr_(const LetNode* op) final { |

137 | auto it = var_map_.find(op->var); |

138 | // if the var has not been binded, update the info. |

139 | if (it == var_map_.end()) { |

140 | var_map_[op->var] = this->VisitExpr(op->value); |

141 | Entry ret = VisitExpr(op->body); |

142 | var_map_.erase(op->var); |

143 | return ret; |

144 | } else { |

145 | return VisitExpr(op->body); |

146 | } |

147 | } |

148 | |

149 | Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } |

150 | |

151 | Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); } |

152 | |

153 | Entry VisitExpr_(const AddNode* op) final { |

154 | Entry a = VisitExpr(op->a); |

155 | Entry b = VisitExpr(op->b); |

156 | int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); |

157 | return Entry(coeff, a.base + b.base); |

158 | } |

159 | |

160 | Entry VisitExpr_(const SubNode* op) final { |

161 | Entry a = VisitExpr(op->a); |

162 | Entry b = VisitExpr(op->b); |

163 | int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); |

164 | return Entry(coeff, a.base - b.base); |

165 | } |

166 | |

167 | Entry VisitExpr_(const MulNode* op) final { |

168 | Entry a = VisitExpr(op->a); |

169 | Entry b = VisitExpr(op->b); |

170 | // Simplification rule, x, y, z are in Z |

171 | // (p x + n) (q y + m) |

172 | // -> pq xy + pm x + qn y + mn |

173 | // -> pq z + pm x + qn y + mn |

174 | int64_t pq = a.coeff * b.coeff; |

175 | int64_t pm = a.coeff * b.base; |

176 | int64_t qn = a.base * b.coeff; |

177 | int64_t coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); |

178 | return Entry(coeff, a.base * b.base); |

179 | } |

180 | |

181 | Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) { |

182 | Entry a = VisitExpr(lhs); |

183 | ICHECK_NE(val, 0); |

184 | if (a.coeff % val == 0) { |

185 | if (a.base == 0) { |

186 | // a c x / c -> a x |

187 | return Entry(std::abs(a.coeff / val), 0); |

188 | } |

189 | // positive division have a clear rounding mode. |

190 | // Only handle case where we clearly know we need to round down. |

191 | if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { |

192 | return Entry(a.coeff / val, a.base / val); |

193 | } |

194 | } |

195 | return Everything(); |

196 | } |

197 | |

198 | Entry VisitExpr_(const DivNode* op) final { |

199 | Entry b = VisitExpr(op->b); |

200 | if (b.is_const()) { |

201 | return DivByConst(op->a, b.base, false); |

202 | } |

203 | return Everything(); |

204 | } |

205 | |

206 | Entry VisitExpr_(const FloorDivNode* op) final { |

207 | Entry b = VisitExpr(op->b); |

208 | if (b.is_const()) { |

209 | return DivByConst(op->a, b.base, true); |

210 | } |

211 | return Everything(); |

212 | } |

213 | |

214 | Entry VisitExpr_(const MinNode* op) final { |

215 | Entry a = VisitExpr(op->a); |

216 | Entry b = VisitExpr(op->b); |

217 | return Union(a, b); |

218 | } |

219 | |

220 | Entry VisitExpr_(const MaxNode* op) final { |

221 | Entry a = VisitExpr(op->a); |

222 | Entry b = VisitExpr(op->b); |

223 | return Union(a, b); |

224 | } |

225 | |

226 | Entry VisitExpr_(const SelectNode* op) final { |

227 | Entry a = VisitExpr(op->true_value); |

228 | Entry b = VisitExpr(op->false_value); |

229 | return Union(a, b); |

230 | } |

231 | |

232 | Entry ModByConst(const PrimExpr& lhs, int64_t val, bool round_down) { |

233 | Entry a = VisitExpr(lhs); |

234 | ICHECK_NE(val, 0); |

235 | int64_t coeff = ZeroAwareGCD(a.coeff, val); |

236 | if (a.base % coeff == 0 || |

237 | (a.base > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0)))) { |

238 | return Entry(coeff, a.base % coeff); |

239 | } |

240 | return Everything(); |

241 | } |

242 | |

243 | Entry VisitExpr_(const FloorModNode* op) final { |

244 | Entry b = VisitExpr(op->b); |

245 | if (b.is_const()) { |

246 | return ModByConst(op->a, b.base, true); |

247 | } |

248 | return Everything(); |

249 | } |

250 | |

251 | Entry VisitExpr_(const ModNode* op) final { |

252 | Entry b = VisitExpr(op->b); |

253 | if (b.is_const()) { |

254 | return ModByConst(op->a, b.base, false); |

255 | } |

256 | return Everything(); |

257 | } |

258 | |

259 | Entry VisitExpr_(const CallNode* op) final { |

260 | // only special handle >> which can be |

261 | // used for index calculation. |

262 | if (op->op.same_as(tir::builtin::shift_right())) { |

263 | return VisitRightShift(op); |

264 | } else if (op->op.same_as(tir::builtin::bitwise_and())) { |

265 | return VisitBitwiseAnd(op); |

266 | } else { |

267 | return Everything(); |

268 | } |

269 | } |

270 | |

271 | Entry VisitExpr_(const VarNode* op) final { |

272 | Var v = GetRef<Var>(op); |

273 | auto it = var_map_.find(v); |

274 | if (it != var_map_.end()) { |

275 | return it->second; |

276 | } else { |

277 | return Everything(); |

278 | } |

279 | } |

280 | |

281 | Entry VisitRightShift(const CallNode* op) { |

282 | Entry b = VisitExpr(op->args[1]); |

283 | // a c x / c -> a x |

284 | if (b.is_const()) { |

285 | return DivByConst(op->args[0], static_cast<int64_t>(1) << b.base, true); |

286 | } |

287 | return Everything(); |

288 | } |

289 | |

290 | Entry VisitBitwiseAnd(const CallNode* op) { |

291 | Entry b = VisitExpr(op->args[1]); |

292 | if (b.is_const()) { |

293 | int shift; |

294 | if (is_const_power_of_two_integer(Integer(b.base + 1), &shift)) { |

295 | return ModByConst(op->args[0], static_cast<int64_t>(1) << shift, true); |

296 | } |

297 | } |

298 | return Everything(); |

299 | } |

300 | |

301 | private: |

302 | /*! \brief pointer to parent. */ |

303 | Analyzer* parent_{nullptr}; |

304 | // internal variable map |

305 | std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> var_map_; |

306 | /*! |

307 | * \brief Update var by intersecting entry with var's current set. |

308 | * \param var The variable. |

309 | * \param entry The entry to be updated. |

310 | * \return The recovery function of the scope. |

311 | */ |

312 | std::function<void()> UpdateByIntersect(const Var& var, Entry entry) { |

313 | Entry old = Everything(); |

314 | auto it = var_map_.find(var); |

315 | if (it != var_map_.end()) { |

316 | old = it->second; |

317 | } |

318 | var_map_[var] = Intersect(old, entry); |

319 | // reover function. |

320 | return [this, old, var]() { var_map_[var] = old; }; |

321 | } |

322 | /*! |

323 | * \brief Create union of two sets. |

324 | * \param a The left operand. |

325 | * \param b the right operand. |

326 | */ |

327 | static Entry Union(Entry a, Entry b) { |

328 | // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}} |

329 | int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); |

330 | if (coeff == 0) { |

331 | if (a.base == b.base) return a; |

332 | return Everything(); |

333 | } |

334 | int64_t base0 = a.base % coeff; |

335 | int64_t base1 = b.base % coeff; |

336 | if (base0 == base1) { |

337 | return Entry(coeff, base0); |

338 | } else { |

339 | return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0); |

340 | } |

341 | } |

342 | |

343 | /*! |

344 | * \brief Create interect of two sets. |

345 | * \param a The left operand. |

346 | * \param b the right operand. |

347 | */ |

348 | static Entry Intersect(Entry a, Entry b) { |

349 | int64_t x, y; |

350 | int64_t c1 = a.coeff, b1 = a.base, c2 = b.coeff, b2 = b.base; |

351 | // z = c1 * p + b1 |

352 | // z = c2 * q + b2 |

353 | // c1 * x + c2 * y = gcd(c1, c2) |

354 | // -> c1 * p - c2 * q = b2 - b1 |

355 | // -> p = (b2 - b1) / gcd * x |

356 | // -> q = (b2 - b1) / gcd * (-y) |

357 | // -> z = LCM(x, y) * k + (c1 * p + b1) |

358 | int64_t gcd = ExtendedEuclidean(c1, c2, &x, &y); |

359 | int64_t v = b2 - b1; |

360 | if (v % gcd == 0) { |

361 | x = v / gcd * x; |

362 | y = v / gcd * (-y); |

363 | int64_t coeff = c1 / gcd * c2; |

364 | return Entry(coeff, x * c1 + b1); |

365 | } else { |

366 | return Nothing(); |

367 | } |

368 | } |

369 | /*! |

370 | * \brief return everything dtype can represent. |

371 | * \return Bound that represent everything dtype can represent. |

372 | */ |

373 | static Entry Everything() { return Entry(1, 0); } |

374 | /*! |

375 | * \brief return an empty set |

376 | * \return Bound that represent everything dtype can represent. |

377 | */ |

378 | static Entry Nothing() { return Entry(0, 1); } |

379 | }; |

380 | |

381 | ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { |

382 | Entry ret = impl_->VisitExpr(expr); |

383 | return ModularSet(ret.coeff, ret.base); |

384 | } |

385 | |

386 | void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool allow_override) { |

387 | impl_->Update(var, info, allow_override); |

388 | } |

389 | |

390 | std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) { |

391 | return impl_->EnterConstraint(constraint); |

392 | } |

393 | |

394 | ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} |

395 | |

396 | ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } |

397 | |

398 | } // namespace arith |

399 | } // namespace tvm |

400 |