1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file narrow_predicate_expression.cc
 * \brief Utility to narrow a predicate so that it no longer contains free parameters
23 */
24#include <tvm/arith/int_solver.h>
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/expr.h>
28#include <tvm/tir/op.h>
29#include <tvm/tir/stmt_functor.h>
30
31namespace tvm {
32namespace arith {
33
34using namespace tir;
35
/* \brief Given a true expression that includes free parameters,
 * generate a true expression without the free parameters.
 *
 * This function provides two guarantees:
 *
 * 1. If the resulting expression evaluates to True, then the original
 *    expression also evaluates to True.
 *
 * 2. The resulting expression does not contain any of the free
 *    parameters.
 *
 */
48// Utility for generating a known true expression from an expression
49// with free parameters, and the range of those parameters.
50class ExpressionNarrower : public tir::ExprMutator {
51 public:
52 static PrimExpr Apply(PrimExpr expr, Map<Var, Range> free_parameters) {
53 ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr;
54 ExpressionNarrower mutator(free_parameters);
55 return mutator(expr);
56 }
57
58 private:
59 explicit ExpressionNarrower(Map<Var, Range> free_parameters)
60 : free_parameters_(free_parameters) {}
61
62 using Parent = tir::ExprMutator;
63 using Parent::VisitExpr_;
64
65 enum class Context {
66 Maximize,
67 Minimize,
68 };
69
70 template <typename T>
71 PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) {
72 PrimExpr a = [&]() {
73 WithContext context(this, a_ctx);
74 return this->VisitExpr(t->a);
75 }();
76
77 PrimExpr b = [&]() {
78 WithContext context(this, b_ctx);
79 return this->VisitExpr(t->b);
80 }();
81
82 if (contains_unknown_expr_ && t.dtype().is_bool()) {
83 contains_unknown_expr_ = false;
84 return Bool(CurrentContext() == Context::Minimize);
85 } else if (a.same_as(t->a) && b.same_as(t->b)) {
86 return std::move(t);
87 } else {
88 return T(a, b);
89 }
90 }
91
92 PrimExpr VisitExpr_(const FloorModNode* op) override {
93 // FloorMod is non-monotonic, so inserting min/max won't remove
94 // the free parameters.
95 contains_unknown_expr_ = true;
96 return Parent::VisitExpr_(op);
97 }
98
99 PrimExpr VisitExpr_(const FloorDivNode* op) override {
100 auto res_a = this->VisitExpr(op->a);
101 auto res_b = this->VisitExpr(op->b);
102 if (is_zero(res_b)) {
103 contains_unknown_expr_ = true;
104 return IntImm(op->dtype, 0);
105 } else {
106 return floordiv(res_a, res_b);
107 }
108 }
109
110 PrimExpr VisitExpr_(const GTNode* op) override {
111 auto current = CurrentContext();
112 return VisitInequality(GetRef<GT>(op), OppositeContext(current), current);
113 }
114
115 PrimExpr VisitExpr_(const GENode* op) override {
116 auto current = CurrentContext();
117 return VisitInequality(GetRef<GE>(op), OppositeContext(current), current);
118 }
119
120 PrimExpr VisitExpr_(const LTNode* op) override {
121 auto current = CurrentContext();
122 return VisitInequality(GetRef<LT>(op), current, OppositeContext(current));
123 }
124
125 PrimExpr VisitExpr_(const LENode* op) override {
126 auto current = CurrentContext();
127 return VisitInequality(GetRef<LE>(op), current, OppositeContext(current));
128 }
129
130 PrimExpr VisitExpr_(const EQNode* op) override {
131 auto res_a = this->VisitExpr(op->a <= op->b);
132 auto res_b = this->VisitExpr(op->b <= op->a);
133 return res_a && res_b;
134 }
135
136 PrimExpr VisitExpr_(const NENode* op) override {
137 auto res_a = this->VisitExpr(op->a < op->b);
138 auto res_b = this->VisitExpr(op->b < op->a);
139 return res_a || res_b;
140 }
141
142 PrimExpr VisitExpr_(const SubNode* op) override {
143 auto current = CurrentContext();
144 return VisitInequality(GetRef<Sub>(op), current, OppositeContext(current));
145 }
146
147 PrimExpr VisitExpr_(const NotNode* op) override {
148 auto current = CurrentContext();
149 WithContext context(this, OppositeContext(current));
150 return !VisitExpr(op->a);
151 }
152
153 PrimExpr VisitExpr_(const BufferLoadNode* op) override {
154 contains_unknown_expr_ = true;
155 return GetRef<PrimExpr>(op);
156 }
157
158 PrimExpr VisitExpr_(const VarNode* op) override {
159 auto it = free_parameters_.find(GetRef<Var>(op));
160 if (it == free_parameters_.end()) {
161 return Parent::VisitExpr_(op);
162 }
163
164 Range range = (*it).second;
165
166 switch (CurrentContext()) {
167 case Context::Minimize:
168 return range->min;
169
170 case Context::Maximize:
171 return range->min + range->extent - 1;
172 }
173
174 return Parent::VisitExpr_(op);
175 }
176
177 Context CurrentContext() const {
178 if (context_stack_.size()) {
179 return context_stack_.back();
180 } else {
181 return Context::Maximize;
182 }
183 }
184
185 Context OppositeContext(Context context) const {
186 switch (context) {
187 case Context::Minimize:
188 return Context::Maximize;
189
190 case Context::Maximize:
191 return Context::Minimize;
192
193 default:
194 LOG(FATAL) << "Unhandled Context, all legal values should be handled";
195 }
196 }
197
198 struct WithContext {
199 WithContext(ExpressionNarrower* self, Context context) : self(self) {
200 self->context_stack_.push_back(context);
201 }
202 ~WithContext() { self->context_stack_.pop_back(); }
203 ExpressionNarrower* self;
204 };
205
206 std::vector<Context> context_stack_;
207 Map<Var, Range> free_parameters_;
208 bool contains_unknown_expr_{false};
209};
210
211PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<Var, Range> free_parameters) {
212 return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters));
213}
214
// Register NarrowPredicateExpression as a global function under the
// name "arith.NarrowPredicateExpression".
TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression);
216
217} // namespace arith
218} // namespace tvm
219