1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file narrow_predicate_expression.cc
 * \brief Utility to narrow a boolean predicate so that it no longer depends on free parameters
 */
24 | #include <tvm/arith/int_solver.h> |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/stmt_functor.h> |
30 | |
31 | namespace tvm { |
32 | namespace arith { |
33 | |
34 | using namespace tir; |
35 | |
36 | /* \brief Given a true expression that includes free parameter, |
37 | * generate a true expression without the free parameters. |
38 | * |
39 | * This function provides two guarantees: |
40 | * |
41 | * 1. If the resulting expression evaluates to True, then the original |
42 | * expression also evaluates to True. |
43 | * |
44 | * 2. The resulting expression does not contain any of the free |
45 | * parameters. |
46 | * |
47 | */ |
48 | // Utility for generating a known true expression from an expression |
49 | // with free parameters, and the range of those parameters. |
50 | class ExpressionNarrower : public tir::ExprMutator { |
51 | public: |
52 | static PrimExpr Apply(PrimExpr expr, Map<Var, Range> free_parameters) { |
53 | ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; |
54 | ExpressionNarrower mutator(free_parameters); |
55 | return mutator(expr); |
56 | } |
57 | |
58 | private: |
59 | explicit ExpressionNarrower(Map<Var, Range> free_parameters) |
60 | : free_parameters_(free_parameters) {} |
61 | |
62 | using Parent = tir::ExprMutator; |
63 | using Parent::VisitExpr_; |
64 | |
65 | enum class Context { |
66 | Maximize, |
67 | Minimize, |
68 | }; |
69 | |
70 | template <typename T> |
71 | PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) { |
72 | PrimExpr a = [&]() { |
73 | WithContext context(this, a_ctx); |
74 | return this->VisitExpr(t->a); |
75 | }(); |
76 | |
77 | PrimExpr b = [&]() { |
78 | WithContext context(this, b_ctx); |
79 | return this->VisitExpr(t->b); |
80 | }(); |
81 | |
82 | if (contains_unknown_expr_ && t.dtype().is_bool()) { |
83 | contains_unknown_expr_ = false; |
84 | return Bool(CurrentContext() == Context::Minimize); |
85 | } else if (a.same_as(t->a) && b.same_as(t->b)) { |
86 | return std::move(t); |
87 | } else { |
88 | return T(a, b); |
89 | } |
90 | } |
91 | |
92 | PrimExpr VisitExpr_(const FloorModNode* op) override { |
93 | // FloorMod is non-monotonic, so inserting min/max won't remove |
94 | // the free parameters. |
95 | contains_unknown_expr_ = true; |
96 | return Parent::VisitExpr_(op); |
97 | } |
98 | |
99 | PrimExpr VisitExpr_(const FloorDivNode* op) override { |
100 | auto res_a = this->VisitExpr(op->a); |
101 | auto res_b = this->VisitExpr(op->b); |
102 | if (is_zero(res_b)) { |
103 | contains_unknown_expr_ = true; |
104 | return IntImm(op->dtype, 0); |
105 | } else { |
106 | return floordiv(res_a, res_b); |
107 | } |
108 | } |
109 | |
110 | PrimExpr VisitExpr_(const GTNode* op) override { |
111 | auto current = CurrentContext(); |
112 | return VisitInequality(GetRef<GT>(op), OppositeContext(current), current); |
113 | } |
114 | |
115 | PrimExpr VisitExpr_(const GENode* op) override { |
116 | auto current = CurrentContext(); |
117 | return VisitInequality(GetRef<GE>(op), OppositeContext(current), current); |
118 | } |
119 | |
120 | PrimExpr VisitExpr_(const LTNode* op) override { |
121 | auto current = CurrentContext(); |
122 | return VisitInequality(GetRef<LT>(op), current, OppositeContext(current)); |
123 | } |
124 | |
125 | PrimExpr VisitExpr_(const LENode* op) override { |
126 | auto current = CurrentContext(); |
127 | return VisitInequality(GetRef<LE>(op), current, OppositeContext(current)); |
128 | } |
129 | |
130 | PrimExpr VisitExpr_(const EQNode* op) override { |
131 | auto res_a = this->VisitExpr(op->a <= op->b); |
132 | auto res_b = this->VisitExpr(op->b <= op->a); |
133 | return res_a && res_b; |
134 | } |
135 | |
136 | PrimExpr VisitExpr_(const NENode* op) override { |
137 | auto res_a = this->VisitExpr(op->a < op->b); |
138 | auto res_b = this->VisitExpr(op->b < op->a); |
139 | return res_a || res_b; |
140 | } |
141 | |
142 | PrimExpr VisitExpr_(const SubNode* op) override { |
143 | auto current = CurrentContext(); |
144 | return VisitInequality(GetRef<Sub>(op), current, OppositeContext(current)); |
145 | } |
146 | |
147 | PrimExpr VisitExpr_(const NotNode* op) override { |
148 | auto current = CurrentContext(); |
149 | WithContext context(this, OppositeContext(current)); |
150 | return !VisitExpr(op->a); |
151 | } |
152 | |
153 | PrimExpr VisitExpr_(const BufferLoadNode* op) override { |
154 | contains_unknown_expr_ = true; |
155 | return GetRef<PrimExpr>(op); |
156 | } |
157 | |
158 | PrimExpr VisitExpr_(const VarNode* op) override { |
159 | auto it = free_parameters_.find(GetRef<Var>(op)); |
160 | if (it == free_parameters_.end()) { |
161 | return Parent::VisitExpr_(op); |
162 | } |
163 | |
164 | Range range = (*it).second; |
165 | |
166 | switch (CurrentContext()) { |
167 | case Context::Minimize: |
168 | return range->min; |
169 | |
170 | case Context::Maximize: |
171 | return range->min + range->extent - 1; |
172 | } |
173 | |
174 | return Parent::VisitExpr_(op); |
175 | } |
176 | |
177 | Context CurrentContext() const { |
178 | if (context_stack_.size()) { |
179 | return context_stack_.back(); |
180 | } else { |
181 | return Context::Maximize; |
182 | } |
183 | } |
184 | |
185 | Context OppositeContext(Context context) const { |
186 | switch (context) { |
187 | case Context::Minimize: |
188 | return Context::Maximize; |
189 | |
190 | case Context::Maximize: |
191 | return Context::Minimize; |
192 | |
193 | default: |
194 | LOG(FATAL) << "Unhandled Context, all legal values should be handled" ; |
195 | } |
196 | } |
197 | |
198 | struct WithContext { |
199 | WithContext(ExpressionNarrower* self, Context context) : self(self) { |
200 | self->context_stack_.push_back(context); |
201 | } |
202 | ~WithContext() { self->context_stack_.pop_back(); } |
203 | ExpressionNarrower* self; |
204 | }; |
205 | |
206 | std::vector<Context> context_stack_; |
207 | Map<Var, Range> free_parameters_; |
208 | bool contains_unknown_expr_{false}; |
209 | }; |
210 | |
211 | PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<Var, Range> free_parameters) { |
212 | return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); |
213 | } |
214 | |
215 | TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression" ).set_body_typed(NarrowPredicateExpression); |
216 | |
217 | } // namespace arith |
218 | } // namespace tvm |
219 | |