1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file int_constraints.cc |
22 | * \brief The integer constraints data structures. |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/arith/int_solver.h> |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/expr_functor.h> |
29 | #include <tvm/tir/op.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | |
32 | #include <algorithm> |
33 | #include <unordered_map> |
34 | #include <utility> |
35 | |
36 | #include "../tir/transforms/ir_utils.h" |
37 | |
38 | namespace tvm { |
39 | namespace arith { |
40 | |
41 | Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds, |
42 | const Array<PrimExpr>& relations) { |
43 | Array<PrimExpr> res; |
44 | // use variables to keep the order of iteration |
45 | // so as to get rid of any non-determinism. |
46 | ICHECK_EQ(variables.size(), bounds.size()); |
47 | for (const auto v : variables) { |
48 | ICHECK(bounds.count(v)); |
49 | const auto& bnds = bounds[v]; |
50 | PrimExpr lhs = bnds->coef * v; |
51 | for (const PrimExpr& rhs : bnds->equal) { |
52 | res.push_back(lhs == rhs); |
53 | } |
54 | for (const PrimExpr& rhs : bnds->lower) { |
55 | res.push_back(lhs >= rhs); |
56 | } |
57 | for (const PrimExpr& rhs : bnds->upper) { |
58 | res.push_back(lhs <= rhs); |
59 | } |
60 | } |
61 | for (const PrimExpr& e : relations) { |
62 | res.push_back(e); |
63 | } |
64 | return res; |
65 | } |
66 | |
67 | IntGroupBounds::IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal, |
68 | Array<PrimExpr> upper) { |
69 | ICHECK(coef.dtype().is_int() || coef.dtype().is_uint()) |
70 | << "Coefficient in IntGroupBounds must be integers" ; |
71 | ObjectPtr<IntGroupBoundsNode> node = make_object<IntGroupBoundsNode>(); |
72 | node->coef = std::move(coef); |
73 | node->lower = std::move(lower); |
74 | node->equal = std::move(equal); |
75 | node->upper = std::move(upper); |
76 | data_ = std::move(node); |
77 | } |
78 | |
79 | IntGroupBounds IntGroupBounds::FromRange(const Range& r) { |
80 | Analyzer analyzer; |
81 | PrimExpr coef = tir::make_const(r->min.dtype(), 1); |
82 | Array<PrimExpr> equal; |
83 | Array<PrimExpr> lower; |
84 | Array<PrimExpr> upper; |
85 | if (tir::is_one(r->extent)) { |
86 | equal.push_back(r->min); |
87 | } else { |
88 | lower.push_back(r->min); |
89 | upper.push_back(analyzer.Simplify(r->min + r->extent - 1)); |
90 | } |
91 | return IntGroupBounds(coef, lower, equal, upper); |
92 | } |
93 | |
94 | IntGroupBounds IntGroupBounds::operator+(const Range& r) { |
95 | Analyzer analyzer; |
96 | Array<PrimExpr> equal; |
97 | Array<PrimExpr> lower; |
98 | Array<PrimExpr> upper; |
99 | const PrimExpr& coef = operator->()->coef; |
100 | if (tir::is_one(r->extent)) { |
101 | equal.push_back(analyzer.Simplify(r->min * coef)); |
102 | } else { |
103 | lower.push_back(analyzer.Simplify(r->min * coef)); |
104 | upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * coef)); |
105 | } |
106 | for (const auto& eq : operator->()->equal) equal.push_back(eq); |
107 | for (const auto& lb : operator->()->lower) lower.push_back(lb); |
108 | for (const auto& ub : operator->()->upper) upper.push_back(ub); |
109 | return IntGroupBounds(coef, lower, equal, upper); |
110 | } |
111 | |
112 | IntGroupBounds IntGroupBounds::Substitute(const Map<Var, PrimExpr>& subst) const { |
113 | auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; |
114 | return IntGroupBounds(tir::Substitute(operator->()->coef, subst), |
115 | tir::UpdateArray(operator->()->lower, apply_fun), |
116 | tir::UpdateArray(operator->()->equal, apply_fun), |
117 | tir::UpdateArray(operator->()->upper, apply_fun)); |
118 | } |
119 | |
120 | Range IntGroupBounds::FindBestRange(const Map<Var, Range>& vranges_addl) const { |
121 | Analyzer analyzer; |
122 | analyzer.Bind(vranges_addl); |
123 | |
124 | std::unordered_map<const VarNode*, IntSet> var_intsets; |
125 | for (auto kv : vranges_addl) { |
126 | var_intsets[kv.first.get()] = IntSet::FromRange(kv.second); |
127 | } |
128 | |
129 | const Array<PrimExpr>& equal = operator->()->equal; |
130 | const PrimExpr& coef = operator->()->coef; |
131 | |
132 | std::vector<PrimExpr> lowers(equal.begin(), equal.end()); |
133 | std::vector<PrimExpr> uppers(equal.begin(), equal.end()); |
134 | for (const auto& expr : operator->()->lower) { |
135 | lowers.push_back(expr); |
136 | } |
137 | for (const auto& expr : operator->()->upper) { |
138 | uppers.push_back(expr); |
139 | } |
140 | |
141 | if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(coef)) { |
142 | return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); |
143 | } |
144 | |
145 | // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the |
146 | // pair with the minimal difference between the upper and the lower. |
147 | // Note that the bounds are for v, not for v*coef |
148 | |
149 | // The lower bound of the best pair so far |
150 | PrimExpr best_lower; |
151 | // The difference between the upper and the lower of the best pair, maybe overapproximation |
152 | PrimExpr best_diff_over; |
153 | |
154 | for (const PrimExpr& low : lowers) { |
155 | for (const PrimExpr& upp : uppers) { |
156 | PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); |
157 | // Since diff may depend on some other variables, we compute its overapproximation |
158 | PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3); |
159 | |
160 | // low is the lower bound for v*coef, but we need the lower bound for v. |
161 | // We use rounding-up division to compute it. Since we want to use a single formula |
162 | PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3); |
163 | |
164 | // Compute another difference which may be more precise (or not). |
165 | PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3); |
166 | PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3); |
167 | |
168 | PrimExpr diff_over = |
169 | analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1; |
170 | |
171 | // If it is provable that the new one is strictly better than the current best one, |
172 | // then replace it. Note that we are biased towards earlier pairs which should be simpler. |
173 | if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) { |
174 | best_lower = low_divided; |
175 | best_diff_over = diff_over; |
176 | } |
177 | } |
178 | } |
179 | |
180 | if (!best_lower.defined()) { |
181 | ICHECK(!best_diff_over.defined()); |
182 | return Range(); |
183 | } |
184 | return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1)); |
185 | } |
186 | |
187 | TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode); |
188 | |
189 | TVM_REGISTER_GLOBAL("arith.IntGroupBounds" ) |
190 | .set_body_typed([](PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal, |
191 | Array<PrimExpr> upper) { |
192 | return IntGroupBounds(coef, lower, equal, upper); |
193 | }); |
194 | |
195 | TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range" ).set_body_typed(IntGroupBounds::FromRange); |
196 | |
197 | TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange" ) |
198 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
199 | ICHECK(args.size() == 1 || args.size() == 2); |
200 | IntGroupBounds bounds = args[0]; |
201 | if (args.size() == 1) { |
202 | *ret = bounds.FindBestRange(); |
203 | } else if (args.size() == 2) { |
204 | *ret = bounds.FindBestRange(args[1]); |
205 | } |
206 | }); |
207 | |
208 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
209 | .set_dispatch<IntGroupBoundsNode>([](const ObjectRef& node, ReprPrinter* p) { |
210 | auto* op = static_cast<const IntGroupBoundsNode*>(node.get()); |
211 | p->stream << "IntGroupBounds(coef=" << op->coef << ", lower=" << op->lower |
212 | << ", equal=" << op->equal << ", upper=" << op->upper << ")" ; |
213 | }); |
214 | |
215 | IntConstraints::IntConstraints(Array<Var> variables, Map<Var, Range> ranges, |
216 | Array<PrimExpr> relations) { |
217 | ObjectPtr<IntConstraintsNode> node = make_object<IntConstraintsNode>(); |
218 | if (!variables.defined()) { |
219 | variables = Array<Var>(); |
220 | } |
221 | if (!ranges.defined()) { |
222 | ranges = Map<Var, Range>(); |
223 | } |
224 | ICHECK(relations.defined()); |
225 | for (const auto& var : variables) { |
226 | ICHECK(var.dtype().is_int() || var.dtype().is_uint()) |
227 | << "Variables in IntConstraints must be integers" ; |
228 | } |
229 | node->variables = std::move(variables); |
230 | node->ranges = std::move(ranges); |
231 | node->relations = std::move(relations); |
232 | data_ = std::move(node); |
233 | } |
234 | |
235 | TVM_REGISTER_NODE_TYPE(IntConstraintsNode); |
236 | |
237 | TVM_REGISTER_GLOBAL("arith.IntConstraints" ) |
238 | .set_body_typed([](Array<Var> variables, Map<Var, Range> ranges, Array<PrimExpr> relations) { |
239 | return IntConstraints(variables, ranges, relations); |
240 | }); |
241 | |
242 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
243 | .set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) { |
244 | auto* op = static_cast<const IntConstraintsNode*>(node.get()); |
245 | p->stream << "IntConstraints(" << op->variables << ", " << op->ranges << ", " << op->relations |
246 | << ")" ; |
247 | }); |
248 | |
249 | IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst, |
250 | Map<Var, PrimExpr> src_to_dst, |
251 | Map<Var, PrimExpr> dst_to_src) { |
252 | ObjectPtr<IntConstraintsTransformNode> node = make_object<IntConstraintsTransformNode>(); |
253 | node->src = std::move(src); |
254 | node->dst = std::move(dst); |
255 | node->src_to_dst = std::move(src_to_dst); |
256 | node->dst_to_src = std::move(dst_to_src); |
257 | data_ = std::move(node); |
258 | } |
259 | |
260 | IntConstraintsTransform IntConstraintsTransform::operator+( |
261 | const IntConstraintsTransform& other) const { |
262 | ICHECK(other->src.same_as(operator->()->dst)); |
263 | Map<Var, PrimExpr> dst_to_src; |
264 | Map<Var, PrimExpr> src_to_dst; |
265 | |
266 | Analyzer ana_first; |
267 | ana_first.Bind(operator->()->src->ranges); |
268 | for (auto p : other->dst_to_src) { |
269 | dst_to_src.Set(p.first, ana_first.Simplify(Substitute(p.second, operator->()->dst_to_src))); |
270 | } |
271 | |
272 | Analyzer ana_second; |
273 | ana_second.Bind(other->dst->ranges); |
274 | for (auto p : operator->()->src_to_dst) { |
275 | src_to_dst.Set(p.first, ana_second.Simplify(Substitute(p.second, other->src_to_dst))); |
276 | } |
277 | return IntConstraintsTransform(operator->()->src, other->dst, src_to_dst, dst_to_src); |
278 | } |
279 | |
280 | TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); |
281 | |
282 | TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform" ) |
283 | .set_body_typed([](IntConstraints src, IntConstraints dst, Map<Var, PrimExpr> src_to_dst, |
284 | Map<Var, PrimExpr> dst_to_src) { |
285 | return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src); |
286 | }); |
287 | |
288 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
289 | .set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) { |
290 | auto* op = static_cast<const IntConstraintsTransformNode*>(node.get()); |
291 | p->stream << "IntConstraintsTransform(" |
292 | << "\n\t" << op->src << "\n\t" << op->dst << "\n\t" << op->src_to_dst << "\n\t" |
293 | << op->dst_to_src << "\n)" ; |
294 | }); |
295 | |
296 | } // namespace arith |
297 | } // namespace tvm |
298 | |