19 | |

/*!
 * \file narrow_predicate_expression.cc
 * \brief Utility to narrow a boolean predicate so that it no longer
 *        depends on a given set of free parameters
 */

24 | #include <tvm/arith/int_solver.h> |

25 | #include <tvm/runtime/registry.h> |

26 | #include <tvm/tir/analysis.h> |

27 | #include <tvm/tir/expr.h> |

28 | #include <tvm/tir/op.h> |

29 | #include <tvm/tir/stmt_functor.h> |

30 | |

31 | namespace tvm { |

32 | namespace arith { |

33 | |

34 | using namespace tir; |

35 | |

36 | /* \brief Given a true expression that includes free parameter, |

37 | * generate a true expression without the free parameters. |

38 | * |

39 | * This function provides two guarantees: |

40 | * |

41 | * 1. If the resulting expression evaluates to True, then the original |

42 | * expression also evaluates to True. |

43 | * |

44 | * 2. The resulting expression does not contain any of the free |

45 | * parameters. |

46 | * |

47 | */ |

48 | // Utility for generating a known true expression from an expression |

49 | // with free parameters, and the range of those parameters. |

50 | class ExpressionNarrower : public tir::ExprMutator { |

51 | public: |

52 | static PrimExpr Apply(PrimExpr expr, Map<Var, Range> free_parameters) { |

53 | ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received "<< expr; |

54 | ExpressionNarrower mutator(free_parameters); |

55 | return mutator(expr); |

56 | } |

57 | |

58 | private: |

59 | explicit ExpressionNarrower(Map<Var, Range> free_parameters) |

60 | : free_parameters_(free_parameters) {} |

61 | |

62 | using Parent = tir::ExprMutator; |

63 | using Parent::VisitExpr_; |

64 | |

65 | enum class Context { |

66 | Maximize, |

67 | Minimize, |

68 | }; |

69 | |

70 | template <typename T> |

71 | PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) { |

72 | PrimExpr a = [&]() { |

73 | WithContext context(this, a_ctx); |

74 | return this->VisitExpr(t->a); |

75 | }(); |

76 | |

77 | PrimExpr b = [&]() { |

78 | WithContext context(this, b_ctx); |

79 | return this->VisitExpr(t->b); |

80 | }(); |

81 | |

82 | if (contains_unknown_expr_ && t.dtype().is_bool()) { |

83 | contains_unknown_expr_ = false; |

84 | return Bool(CurrentContext() == Context::Minimize); |

85 | } else if (a.same_as(t->a) && b.same_as(t->b)) { |

86 | return std::move(t); |

87 | } else { |

88 | return T(a, b); |

89 | } |

90 | } |

91 | |

92 | PrimExpr VisitExpr_(const FloorModNode* op) override { |

93 | // FloorMod is non-monotonic, so inserting min/max won't remove |

94 | // the free parameters. |

95 | contains_unknown_expr_ = true; |

96 | return Parent::VisitExpr_(op); |

97 | } |

98 | |

99 | PrimExpr VisitExpr_(const FloorDivNode* op) override { |

100 | auto res_a = this->VisitExpr(op->a); |

101 | auto res_b = this->VisitExpr(op->b); |

102 | if (is_zero(res_b)) { |

103 | contains_unknown_expr_ = true; |

104 | return IntImm(op->dtype, 0); |

105 | } else { |

106 | return floordiv(res_a, res_b); |

107 | } |

108 | } |

109 | |

110 | PrimExpr VisitExpr_(const GTNode* op) override { |

111 | auto current = CurrentContext(); |

112 | return VisitInequality(GetRef<GT>(op), OppositeContext(current), current); |

113 | } |

114 | |

115 | PrimExpr VisitExpr_(const GENode* op) override { |

116 | auto current = CurrentContext(); |

117 | return VisitInequality(GetRef<GE>(op), OppositeContext(current), current); |

118 | } |

119 | |

120 | PrimExpr VisitExpr_(const LTNode* op) override { |

121 | auto current = CurrentContext(); |

122 | return VisitInequality(GetRef<LT>(op), current, OppositeContext(current)); |

123 | } |

124 | |

125 | PrimExpr VisitExpr_(const LENode* op) override { |

126 | auto current = CurrentContext(); |

127 | return VisitInequality(GetRef<LE>(op), current, OppositeContext(current)); |

128 | } |

129 | |

130 | PrimExpr VisitExpr_(const EQNode* op) override { |

131 | auto res_a = this->VisitExpr(op->a <= op->b); |

132 | auto res_b = this->VisitExpr(op->b <= op->a); |

133 | return res_a && res_b; |

134 | } |

135 | |

136 | PrimExpr VisitExpr_(const NENode* op) override { |

137 | auto res_a = this->VisitExpr(op->a < op->b); |

138 | auto res_b = this->VisitExpr(op->b < op->a); |

139 | return res_a || res_b; |

140 | } |

141 | |

142 | PrimExpr VisitExpr_(const SubNode* op) override { |

143 | auto current = CurrentContext(); |

144 | return VisitInequality(GetRef<Sub>(op), current, OppositeContext(current)); |

145 | } |

146 | |

147 | PrimExpr VisitExpr_(const NotNode* op) override { |

148 | auto current = CurrentContext(); |

149 | WithContext context(this, OppositeContext(current)); |

150 | return !VisitExpr(op->a); |

151 | } |

152 | |

153 | PrimExpr VisitExpr_(const BufferLoadNode* op) override { |

154 | contains_unknown_expr_ = true; |

155 | return GetRef<PrimExpr>(op); |

156 | } |

157 | |

158 | PrimExpr VisitExpr_(const VarNode* op) override { |

159 | auto it = free_parameters_.find(GetRef<Var>(op)); |

160 | if (it == free_parameters_.end()) { |

161 | return Parent::VisitExpr_(op); |

162 | } |

163 | |

164 | Range range = (*it).second; |

165 | |

166 | switch (CurrentContext()) { |

167 | case Context::Minimize: |

168 | return range->min; |

169 | |

170 | case Context::Maximize: |

171 | return range->min + range->extent - 1; |

172 | } |

173 | |

174 | return Parent::VisitExpr_(op); |

175 | } |

176 | |

177 | Context CurrentContext() const { |

178 | if (context_stack_.size()) { |

179 | return context_stack_.back(); |

180 | } else { |

181 | return Context::Maximize; |

182 | } |

183 | } |

184 | |

185 | Context OppositeContext(Context context) const { |

186 | switch (context) { |

187 | case Context::Minimize: |

188 | return Context::Maximize; |

189 | |

190 | case Context::Maximize: |

191 | return Context::Minimize; |

192 | |

193 | default: |

194 | LOG(FATAL) << "Unhandled Context, all legal values should be handled"; |

195 | } |

196 | } |

197 | |

198 | struct WithContext { |

199 | WithContext(ExpressionNarrower* self, Context context) : self(self) { |

200 | self->context_stack_.push_back(context); |

201 | } |

202 | ~WithContext() { self->context_stack_.pop_back(); } |

203 | ExpressionNarrower* self; |

204 | }; |

205 | |

206 | std::vector<Context> context_stack_; |

207 | Map<Var, Range> free_parameters_; |

208 | bool contains_unknown_expr_{false}; |

209 | }; |

210 | |

211 | PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<Var, Range> free_parameters) { |

212 | return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); |

213 | } |

214 | |

215 | TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression); |

216 | |

217 | } // namespace arith |

218 | } // namespace tvm |

219 |