1 | /* |
---|---|

2 | * Licensed to the Apache Software Foundation (ASF) under one |

3 | * or more contributor license agreements. See the NOTICE file |

4 | * distributed with this work for additional information |

5 | * regarding copyright ownership. The ASF licenses this file |

6 | * to you under the Apache License, Version 2.0 (the |

7 | * "License"); you may not use this file except in compliance |

8 | * with the License. You may obtain a copy of the License at |

9 | * |

10 | * http://www.apache.org/licenses/LICENSE-2.0 |

11 | * |

12 | * Unless required by applicable law or agreed to in writing, |

13 | * software distributed under the License is distributed on an |

14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |

15 | * KIND, either express or implied. See the License for the |

16 | * specific language governing permissions and limitations |

17 | * under the License. |

18 | */ |

19 | |

20 | /*! |

21 | * \file bound_deducer.cc |

22 | * \brief Utility to deduce bound of expression |

23 | */ |

24 | #include <tvm/arith/analyzer.h> |

25 | #include <tvm/runtime/registry.h> |

26 | #include <tvm/tir/expr.h> |

27 | #include <tvm/tir/expr_functor.h> |

28 | |

29 | #include <unordered_map> |

30 | #include <unordered_set> |

31 | |

32 | #include "interval_set.h" |

33 | |

34 | namespace tvm { |

35 | namespace arith { |

36 | |

37 | using namespace tir; |

38 | |

39 | // a visitor to find the path to the target variable |

40 | // from a expression. |

41 | class VariablePathFinder : public ExprVisitor { |

42 | public: |

43 | explicit VariablePathFinder(PrimExpr target) : target_(target) {} |

44 | |

45 | void VisitExpr(const PrimExpr& node) final { |

46 | if (visited_.count(node.get()) != 0) return; |

47 | visited_.insert(node.get()); |

48 | |

49 | if (!found_) path_.push_back(node.get()); |

50 | if (node.same_as(target_)) found_ = true; |

51 | ExprVisitor::VisitExpr(node); |

52 | if (!found_) path_.pop_back(); |

53 | } |

54 | |

55 | std::vector<const Object*> path_; |

56 | |

57 | private: |

58 | bool found_{false}; |

59 | PrimExpr target_; |

60 | std::unordered_set<const Object*> visited_; |

61 | }; |

62 | |

63 | // get the path to the variable, |

64 | // return empty vector to represent failure |

65 | std::vector<const Object*> GetPath(PrimExpr target, PrimExpr expr) { |

66 | VariablePathFinder v(target); |

67 | v(expr); |

68 | return v.path_; |

69 | } |

70 | |

// Direction of the relation `expr comp_op bound` handled by BoundDeducer.
enum CompareOp {
  kGreater,  // expr >= bound
  kLess,     // expr <= bound
  kEqual,    // expr == bound
};

72 | |

73 | // a visitor to deduce the bound of a variable from a expression |

74 | class BoundDeducer : public ExprFunctor<void(const PrimExpr&)> { |

75 | public: |

76 | friend class BoundDeduceInputChecker; |

77 | friend class Converter; |

78 | BoundDeducer(PrimExpr target, PrimExpr expr, |

79 | const std::unordered_map<const VarNode*, IntSet>& hint_map, |

80 | const std::unordered_map<const VarNode*, IntSet>& relax_map) |

81 | : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} |

82 | |

83 | void Deduce(); |

84 | |

85 | void VisitExpr(const PrimExpr& e) final { |

86 | if (!success_) return; |

87 | if (iter_ < path_.size() && e.get() == path_[iter_++]) { |

88 | ExprFunctor::VisitExpr(e); |

89 | } else { |

90 | success_ = false; |

91 | return; |

92 | } |

93 | } |

94 | |

95 | void VisitExprDefault_(const Object* op) final { success_ = false; } |

96 | |

97 | SignType GetSignType(const PrimExpr& e) { |

98 | if (e.dtype().is_uint()) { |

99 | return kPositive; |

100 | } |

101 | return expr_map_[e].GetSignType(); |

102 | } |

103 | |

104 | void VisitExpr_(const VarNode* op) final {} |

105 | |

106 | void VisitExpr_(const AddNode* op) final { |

107 | bool left = op->a.get() == path_[iter_]; |

108 | result_ -= left ? op->b : op->a; |

109 | this->VisitExpr(left ? op->a : op->b); |

110 | } |

111 | |

112 | void VisitExpr_(const SubNode* op) final { |

113 | bool left = op->a.get() == path_[iter_]; |

114 | if (left) { |

115 | result_ += op->b; |

116 | } else { |

117 | result_ -= op->a; |

118 | result_ = -result_; |

119 | comp_op = ReverseOp(comp_op); |

120 | } |

121 | this->VisitExpr(left ? op->a : op->b); |

122 | } |

123 | |

124 | void VisitExpr_(const MulNode* op) final { |

125 | bool left = op->a.get() == path_[iter_]; |

126 | PrimExpr operand = left ? op->b : op->a; |

127 | PrimExpr target_var = left ? op->a : op->b; |

128 | |

129 | SignType sign_operand = GetSignType(operand); |

130 | if (sign_operand == SignType::kNegative) { |

131 | comp_op = ReverseOp(comp_op); |

132 | } else if (sign_operand == SignType::kUnknown) { |

133 | // unable to get the sign of operand |

134 | success_ = false; |

135 | return; |

136 | } |

137 | |

138 | // always use relax bound |

139 | bool divided = analyzer_.CanProve(floormod(result_, operand) == 0); |

140 | |

141 | result_ = floordiv(result_, operand); // rounding down here |

142 | |

143 | if (!divided) { |

144 | if (comp_op == kGreater) { |

145 | // System will round down in all the cases, so add one for result_ for kGreater |

146 | // (x >= 3/2 --> x >= 2) |

147 | // (x >= -3/2 --> x >= -1) |

148 | // (x >= 3/-2 --> x >= -1) |

149 | // (x >= -3/-2 --> x >= 2) |

150 | result_ += 1; |

151 | } else if (comp_op == kEqual) { |

152 | // condition unsatisfiable as with floor div, it will change the expression |

153 | success_ = false; |

154 | return; |

155 | } else { |

156 | // System rounds down in all cases, do nothing for kLess. |

157 | // ( x <= 3/2 --> x <= 1) |

158 | // ( x <= -3/2 --> x <= -2) |

159 | // ( x <= 3/-2 --> x <= -2) |

160 | // ( x <= -3/-2 --> x <= 1) |

161 | } |

162 | } |

163 | this->VisitExpr(left ? op->a : op->b); |

164 | } |

165 | |

166 | void VisitExpr_(const FloorDivNode* op) final { |

167 | if (op->b.get() == path_[iter_]) { |

168 | // Skip cases where the var is divisor. |

169 | success_ = false; |

170 | return; |

171 | } |

172 | PrimExpr divisor = op->b; |

173 | if (analyzer_.CanProveEqual(divisor, 0)) { |

174 | // Skip zero divisor |

175 | success_ = false; |

176 | return; |

177 | } |

178 | |

179 | SignType sign_operand = GetSignType(divisor); |

180 | if (sign_operand == SignType::kNegative) { |

181 | comp_op = ReverseOp(comp_op); |

182 | divisor = -divisor; |

183 | result_ = -result_; |

184 | } else if (sign_operand == SignType::kUnknown) { |

185 | // unable to get the sign of operand |

186 | success_ = false; |

187 | return; |

188 | } |

189 | |

190 | if (comp_op == kGreater) { |

191 | // (x // 6 >= 4 --> x >= 4 * 6) |

192 | result_ = result_ * divisor; |

193 | } else if (comp_op == kEqual) { |

194 | // The bound is not single directional |

195 | // (x // 6 == 4 --> 30 > x >= 24) |

196 | // TODO(@wrongtest): support bidirectional bound |

197 | success_ = false; |

198 | return; |

199 | } else { |

200 | // (x // 6 <= 4 --> x <= 4 * 6 + 5) |

201 | result_ = result_ * divisor + divisor - 1; |

202 | } |

203 | if (sign_operand == SignType::kNegative) { |

204 | // (x // -6 >= 4 --> -((x + 6 - 1) // 6) >= 4 |

205 | // --> (x + 6 - 1) // 6 <= -4 |

206 | result_ = result_ - divisor + 1; |

207 | } |

208 | |

209 | this->VisitExpr(op->a); |

210 | } |

211 | |

212 | PrimExpr result_; |

213 | CompareOp comp_op{kGreater}; |

214 | bool success_{true}; |

215 | |

216 | private: |

217 | void Init(); |

218 | void Transform(); |

219 | void Relax(); |

220 | CompareOp ReverseOp(CompareOp comp_op); |

221 | PrimExpr target_; |

222 | PrimExpr expr_; |

223 | const std::unordered_map<const VarNode*, IntSet>& hint_map_; |

224 | const std::unordered_map<const VarNode*, IntSet>& relax_map_; |

225 | ExprIntSetMap expr_map_; |

226 | std::vector<const Object*> path_; |

227 | size_t iter_{0}; |

228 | // internal analzyer |

229 | Analyzer analyzer_; |

230 | }; |

231 | |

232 | class BoundDeduceInputChecker : public ExprVisitor { |

233 | public: |

234 | bool Check(BoundDeducer* deducer) { |

235 | deducer_ = deducer; |

236 | this->VisitExpr(deducer_->expr_); |

237 | return target_count == 1; |

238 | } |

239 | |

240 | void VisitExpr(const PrimExpr& e) final { |

241 | if (e.same_as(deducer_->target_)) ++target_count; |

242 | ExprVisitor::VisitExpr(e); |

243 | } |

244 | |

245 | private: |

246 | BoundDeducer* deducer_; |

247 | size_t target_count{0}; |

248 | }; |

249 | |

250 | void BoundDeducer::Init() { |

251 | BoundDeduceInputChecker checker; |

252 | if (!checker.Check(this)) success_ = false; |

253 | Transform(); |

254 | } |

255 | |

256 | CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { |

257 | switch (comp_op) { |

258 | case kEqual: |

259 | return kEqual; // IntSet can not represent range for `NE |

260 | case kGreater: |

261 | return kLess; |

262 | case kLess: |

263 | return kGreater; |

264 | default: |

265 | LOG(FATAL) << "Not a valid compare op"; |

266 | } |

267 | } |

268 | |

269 | void BoundDeducer::Transform() { |

270 | // We will ensure to set expr_ such that it contains target_ |

271 | if (const LTNode* op = expr_.as<LTNode>()) { |

272 | if (GetPath(target_, op->a).empty()) { |

273 | // a < b -> b >= a + 1 |

274 | comp_op = kGreater; |

275 | expr_ = op->b; |

276 | result_ = op->a + 1; |

277 | } else { |

278 | // a < b -> a <= b - 1 |

279 | comp_op = kLess; |

280 | expr_ = op->a; |

281 | result_ = op->b - 1; |

282 | } |

283 | } else if (const LENode* op = expr_.as<LENode>()) { |

284 | if (GetPath(target_, op->a).empty()) { |

285 | // a <= b -> b >= a |

286 | comp_op = kGreater; |

287 | expr_ = op->b; |

288 | result_ = op->a; |

289 | } else { |

290 | comp_op = kLess; |

291 | expr_ = op->a; |

292 | result_ = op->b; |

293 | } |

294 | } else if (const GTNode* op = expr_.as<GTNode>()) { |

295 | if (GetPath(target_, op->a).empty()) { |

296 | // a > b -> b <= a - 1 |

297 | comp_op = kLess; |

298 | expr_ = op->b; |

299 | result_ = op->a - 1; |

300 | } else { |

301 | // a > b -> a >= b + 1 |

302 | comp_op = kGreater; |

303 | expr_ = op->a; |

304 | result_ = op->b + 1; |

305 | } |

306 | } else if (const GENode* op = expr_.as<GENode>()) { |

307 | if (GetPath(target_, op->a).empty()) { |

308 | // a >= b -> b <= a |

309 | comp_op = kLess; |

310 | expr_ = op->b; |

311 | result_ = op->a; |

312 | } else { |

313 | comp_op = kGreater; |

314 | expr_ = op->a; |

315 | result_ = op->b; |

316 | } |

317 | } else if (const EQNode* op = expr_.as<EQNode>()) { |

318 | comp_op = kEqual; |

319 | if (GetPath(target_, op->a).empty()) { |

320 | // if the b == a -> a == b |

321 | expr_ = op->b; |

322 | result_ = op->a; |

323 | } else { |

324 | expr_ = op->a; |

325 | result_ = op->b; |

326 | } |

327 | } else { |

328 | success_ = false; |

329 | } |

330 | } |

331 | |

332 | void BoundDeducer::Deduce() { |

333 | Init(); |

334 | if (!success_) return; |

335 | |

336 | Relax(); |

337 | if (!success_) return; |

338 | // get the path |

339 | path_ = GetPath(target_, expr_); |

340 | if (!path_.size()) { |

341 | success_ = false; |

342 | return; |

343 | } |

344 | expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); |

345 | |

346 | this->VisitExpr(expr_); |

347 | } |

348 | |

349 | void BoundDeducer::Relax() { |

350 | IntSet a = EvalSet(expr_, relax_map_); |

351 | IntSet b = EvalSet(result_, relax_map_); |

352 | if (a.IsEverything() || b.IsEverything()) { |

353 | success_ = false; |

354 | return; |

355 | } |

356 | // Both LHS and RHS of the EQ should behave as constants e.g. i == j, |

357 | // can not be resolved when either `i` or `j` or both are variables with |

358 | // some Range OR `i` and `j` both should be a single point in IntSet |

359 | if (comp_op == kEqual && |

360 | (!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) { |

361 | success_ = false; |

362 | return; |

363 | } |

364 | expr_ = (comp_op == kGreater) ? a.min() : a.max(); |

365 | result_ = (comp_op == kGreater) ? b.max() : b.min(); |

366 | } |

367 | |

368 | IntSet DeduceBound(PrimExpr v, PrimExpr e, |

369 | const std::unordered_map<const VarNode*, IntSet>& hint_map, |

370 | const std::unordered_map<const VarNode*, IntSet>& relax_map) { |

371 | BoundDeducer d(v, e, hint_map, relax_map); |

372 | d.Deduce(); |

373 | if (!d.success_) return IntSet::Nothing(); |

374 | PrimExpr min = neg_inf(), max = pos_inf(); |

375 | if (d.comp_op == kEqual) { |

376 | min = d.result_; |

377 | max = d.result_; |

378 | } else if (d.comp_op == kGreater) { |

379 | min = d.result_; |

380 | } else { |

381 | max = d.result_; |

382 | } |

383 | return IntSet::Interval(min, max); |

384 | } |

385 | |

386 | // assuming e >= 0, deduce the bound of variable from it. |

387 | // return empty set to represent deduce failure. |

388 | IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map<Var, IntSet>& hint_map, |

389 | const Map<Var, IntSet>& relax_map) { |

390 | std::unordered_map<const VarNode*, IntSet> hmap; |

391 | for (auto kv : hint_map) { |

392 | hmap[kv.first.get()] = kv.second; |

393 | } |

394 | std::unordered_map<const VarNode*, IntSet> rmap; |

395 | for (auto kv : relax_map) { |

396 | rmap[kv.first.get()] = kv.second; |

397 | } |

398 | return DeduceBound(v, e, hmap, rmap); |

399 | } |

400 | |

401 | TVM_REGISTER_GLOBAL("arith.DeduceBound") |

402 | .set_body_typed([](PrimExpr v, PrimExpr cond, const Map<Var, IntSet> hint_map, |

403 | const Map<Var, IntSet> relax_map) { |

404 | return DeduceBound(v, cond, hint_map, relax_map); |

405 | }); |

406 | |

407 | } // namespace arith |

408 | } // namespace tvm |

409 |