1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file int_constraints.cc
22 * \brief The integer constraints data structures.
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/arith/int_solver.h>
26#include <tvm/runtime/registry.h>
27#include <tvm/tir/expr.h>
28#include <tvm/tir/expr_functor.h>
29#include <tvm/tir/op.h>
30#include <tvm/tir/stmt_functor.h>
31
32#include <algorithm>
33#include <unordered_map>
34#include <utility>
35
36#include "../tir/transforms/ir_utils.h"
37
38namespace tvm {
39namespace arith {
40
41Array<PrimExpr> AsConditions(const Array<Var>& variables, const Map<Var, IntGroupBounds>& bounds,
42 const Array<PrimExpr>& relations) {
43 Array<PrimExpr> res;
44 // use variables to keep the order of iteration
45 // so as to get rid of any non-determinism.
46 ICHECK_EQ(variables.size(), bounds.size());
47 for (const auto v : variables) {
48 ICHECK(bounds.count(v));
49 const auto& bnds = bounds[v];
50 PrimExpr lhs = bnds->coef * v;
51 for (const PrimExpr& rhs : bnds->equal) {
52 res.push_back(lhs == rhs);
53 }
54 for (const PrimExpr& rhs : bnds->lower) {
55 res.push_back(lhs >= rhs);
56 }
57 for (const PrimExpr& rhs : bnds->upper) {
58 res.push_back(lhs <= rhs);
59 }
60 }
61 for (const PrimExpr& e : relations) {
62 res.push_back(e);
63 }
64 return res;
65}
66
67IntGroupBounds::IntGroupBounds(PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
68 Array<PrimExpr> upper) {
69 ICHECK(coef.dtype().is_int() || coef.dtype().is_uint())
70 << "Coefficient in IntGroupBounds must be integers";
71 ObjectPtr<IntGroupBoundsNode> node = make_object<IntGroupBoundsNode>();
72 node->coef = std::move(coef);
73 node->lower = std::move(lower);
74 node->equal = std::move(equal);
75 node->upper = std::move(upper);
76 data_ = std::move(node);
77}
78
79IntGroupBounds IntGroupBounds::FromRange(const Range& r) {
80 Analyzer analyzer;
81 PrimExpr coef = tir::make_const(r->min.dtype(), 1);
82 Array<PrimExpr> equal;
83 Array<PrimExpr> lower;
84 Array<PrimExpr> upper;
85 if (tir::is_one(r->extent)) {
86 equal.push_back(r->min);
87 } else {
88 lower.push_back(r->min);
89 upper.push_back(analyzer.Simplify(r->min + r->extent - 1));
90 }
91 return IntGroupBounds(coef, lower, equal, upper);
92}
93
94IntGroupBounds IntGroupBounds::operator+(const Range& r) {
95 Analyzer analyzer;
96 Array<PrimExpr> equal;
97 Array<PrimExpr> lower;
98 Array<PrimExpr> upper;
99 const PrimExpr& coef = operator->()->coef;
100 if (tir::is_one(r->extent)) {
101 equal.push_back(analyzer.Simplify(r->min * coef));
102 } else {
103 lower.push_back(analyzer.Simplify(r->min * coef));
104 upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * coef));
105 }
106 for (const auto& eq : operator->()->equal) equal.push_back(eq);
107 for (const auto& lb : operator->()->lower) lower.push_back(lb);
108 for (const auto& ub : operator->()->upper) upper.push_back(ub);
109 return IntGroupBounds(coef, lower, equal, upper);
110}
111
112IntGroupBounds IntGroupBounds::Substitute(const Map<Var, PrimExpr>& subst) const {
113 auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); };
114 return IntGroupBounds(tir::Substitute(operator->()->coef, subst),
115 tir::UpdateArray(operator->()->lower, apply_fun),
116 tir::UpdateArray(operator->()->equal, apply_fun),
117 tir::UpdateArray(operator->()->upper, apply_fun));
118}
119
/*!
 * \brief Find the narrowest single range for v, given this group's bounds on coef * v.
 *
 * Candidate lower bounds are equal + lower, and candidate upper bounds are equal + upper.
 * If there is exactly one of each and coef is 1, they are used directly. Otherwise every
 * (lower, upper) pair is tried, and the pair whose extent is provably smallest wins. Ties and
 * unprovable comparisons keep the earlier pair. When bounds mention other variables, the extent
 * is over-approximated using the ranges in vranges_addl.
 *
 * \param vranges_addl Ranges of additional variables that may appear in the bounds. They are
 *        bound to the local analyzer and converted to IntSets for extent over-approximation.
 * \return The chosen range for v. An undefined Range() is returned when there is no candidate
 *         lower bound or no candidate upper bound.
 * \note The rounding-up division used below equals ceil(low / coef) only for coef > 0.
 *       Presumably callers guarantee a positive coefficient — TODO confirm.
 */
Range IntGroupBounds::FindBestRange(const Map<Var, Range>& vranges_addl) const {
  Analyzer analyzer;
  analyzer.Bind(vranges_addl);

  // Interval view of the additional variables, used by EvalSet to over-approximate extents.
  std::unordered_map<const VarNode*, IntSet> var_intsets;
  for (auto kv : vranges_addl) {
    var_intsets[kv.first.get()] = IntSet::FromRange(kv.second);
  }

  const Array<PrimExpr>& equal = operator->()->equal;
  const PrimExpr& coef = operator->()->coef;

  // An equality is both a lower and an upper bound, so it seeds both candidate lists.
  std::vector<PrimExpr> lowers(equal.begin(), equal.end());
  std::vector<PrimExpr> uppers(equal.begin(), equal.end());
  for (const auto& expr : operator->()->lower) {
    lowers.push_back(expr);
  }
  for (const auto& expr : operator->()->upper) {
    uppers.push_back(expr);
  }

  // Fast path: a single candidate on each side and a unit coefficient, so the bounds already
  // apply to v. The upper bound is inclusive, so the exclusive end of the range is upper + 1.
  if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(coef)) {
    return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1));
  }

  // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the
  // pair with the minimal difference between the upper and the lower.
  // Note that the bounds are for v, not for v*coef

  // The lower bound of the best pair so far
  PrimExpr best_lower;
  // The difference between the upper and the lower of the best pair, maybe overapproximation
  PrimExpr best_diff_over;

  for (const PrimExpr& low : lowers) {
    for (const PrimExpr& upp : uppers) {
      // First estimate of (max v - min v): floor((upp - low) / coef), which is never smaller
      // than the exact value.
      PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3);
      // Since diff may depend on some other variables, we compute its overapproximation
      PrimExpr diff_over_1 = analyzer.Simplify(EvalSet(diff_1, var_intsets).max(), 3);

      // low is the lower bound for v*coef, but we need the lower bound for v.
      // We use rounding-up division to compute it: for coef > 0,
      // ceil(low / coef) == floordiv(low + coef - 1, coef), so a single floordiv suffices.
      PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3);

      // Compute another difference which may be more precise (or not).
      PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3);
      PrimExpr diff_over_2 = analyzer.Simplify(EvalSet(diff_2, var_intsets).max(), 3);

      // Keep the second estimate only if it is provably tighter than the first.
      PrimExpr diff_over =
          analyzer.CanProve(diff_over_2 - diff_over_1 < 0) ? diff_over_2 : diff_over_1;

      // If it is provable that the new one is strictly better than the current best one,
      // then replace it. Note that we are biased towards earlier pairs which should be simpler.
      if (!best_diff_over.defined() || analyzer.CanProve(diff_over - best_diff_over < 0)) {
        best_lower = low_divided;
        best_diff_over = diff_over;
      }
    }
  }

  // No pair was examined: the group lacks a lower bound or an upper bound.
  if (!best_lower.defined()) {
    ICHECK(!best_diff_over.defined());
    return Range();
  }
  // best_diff_over is (max v - min v), so the extent is one more.
  return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1));
}
186
187TVM_REGISTER_NODE_TYPE(IntGroupBoundsNode);
188
189TVM_REGISTER_GLOBAL("arith.IntGroupBounds")
190 .set_body_typed([](PrimExpr coef, Array<PrimExpr> lower, Array<PrimExpr> equal,
191 Array<PrimExpr> upper) {
192 return IntGroupBounds(coef, lower, equal, upper);
193 });
194
195TVM_REGISTER_GLOBAL("arith.IntGroupBounds_from_range").set_body_typed(IntGroupBounds::FromRange);
196
197TVM_REGISTER_GLOBAL("arith.IntGroupBounds_FindBestRange")
198 .set_body([](TVMArgs args, TVMRetValue* ret) {
199 ICHECK(args.size() == 1 || args.size() == 2);
200 IntGroupBounds bounds = args[0];
201 if (args.size() == 1) {
202 *ret = bounds.FindBestRange();
203 } else if (args.size() == 2) {
204 *ret = bounds.FindBestRange(args[1]);
205 }
206 });
207
208TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
209 .set_dispatch<IntGroupBoundsNode>([](const ObjectRef& node, ReprPrinter* p) {
210 auto* op = static_cast<const IntGroupBoundsNode*>(node.get());
211 p->stream << "IntGroupBounds(coef=" << op->coef << ", lower=" << op->lower
212 << ", equal=" << op->equal << ", upper=" << op->upper << ")";
213 });
214
215IntConstraints::IntConstraints(Array<Var> variables, Map<Var, Range> ranges,
216 Array<PrimExpr> relations) {
217 ObjectPtr<IntConstraintsNode> node = make_object<IntConstraintsNode>();
218 if (!variables.defined()) {
219 variables = Array<Var>();
220 }
221 if (!ranges.defined()) {
222 ranges = Map<Var, Range>();
223 }
224 ICHECK(relations.defined());
225 for (const auto& var : variables) {
226 ICHECK(var.dtype().is_int() || var.dtype().is_uint())
227 << "Variables in IntConstraints must be integers";
228 }
229 node->variables = std::move(variables);
230 node->ranges = std::move(ranges);
231 node->relations = std::move(relations);
232 data_ = std::move(node);
233}
234
235TVM_REGISTER_NODE_TYPE(IntConstraintsNode);
236
237TVM_REGISTER_GLOBAL("arith.IntConstraints")
238 .set_body_typed([](Array<Var> variables, Map<Var, Range> ranges, Array<PrimExpr> relations) {
239 return IntConstraints(variables, ranges, relations);
240 });
241
242TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
243 .set_dispatch<IntConstraintsNode>([](const ObjectRef& node, ReprPrinter* p) {
244 auto* op = static_cast<const IntConstraintsNode*>(node.get());
245 p->stream << "IntConstraints(" << op->variables << ", " << op->ranges << ", " << op->relations
246 << ")";
247 });
248
249IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst,
250 Map<Var, PrimExpr> src_to_dst,
251 Map<Var, PrimExpr> dst_to_src) {
252 ObjectPtr<IntConstraintsTransformNode> node = make_object<IntConstraintsTransformNode>();
253 node->src = std::move(src);
254 node->dst = std::move(dst);
255 node->src_to_dst = std::move(src_to_dst);
256 node->dst_to_src = std::move(dst_to_src);
257 data_ = std::move(node);
258}
259
260IntConstraintsTransform IntConstraintsTransform::operator+(
261 const IntConstraintsTransform& other) const {
262 ICHECK(other->src.same_as(operator->()->dst));
263 Map<Var, PrimExpr> dst_to_src;
264 Map<Var, PrimExpr> src_to_dst;
265
266 Analyzer ana_first;
267 ana_first.Bind(operator->()->src->ranges);
268 for (auto p : other->dst_to_src) {
269 dst_to_src.Set(p.first, ana_first.Simplify(Substitute(p.second, operator->()->dst_to_src)));
270 }
271
272 Analyzer ana_second;
273 ana_second.Bind(other->dst->ranges);
274 for (auto p : operator->()->src_to_dst) {
275 src_to_dst.Set(p.first, ana_second.Simplify(Substitute(p.second, other->src_to_dst)));
276 }
277 return IntConstraintsTransform(operator->()->src, other->dst, src_to_dst, dst_to_src);
278}
279
280TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode);
281
282TVM_REGISTER_GLOBAL("arith.IntConstraintsTransform")
283 .set_body_typed([](IntConstraints src, IntConstraints dst, Map<Var, PrimExpr> src_to_dst,
284 Map<Var, PrimExpr> dst_to_src) {
285 return IntConstraintsTransform(src, dst, src_to_dst, dst_to_src);
286 });
287
288TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
289 .set_dispatch<IntConstraintsTransformNode>([](const ObjectRef& node, ReprPrinter* p) {
290 auto* op = static_cast<const IntConstraintsTransformNode*>(node.get());
291 p->stream << "IntConstraintsTransform("
292 << "\n\t" << op->src << "\n\t" << op->dst << "\n\t" << op->src_to_dst << "\n\t"
293 << op->dst_to_src << "\n)";
294 });
295
296} // namespace arith
297} // namespace tvm
298