1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/int_set.h
22 * \brief Integer set
23 */
24#ifndef TVM_ARITH_INT_SET_H_
25#define TVM_ARITH_INT_SET_H_
26
27#include <tvm/ir/expr.h>
28#include <tvm/tir/expr.h>
29
30#include <unordered_map>
31
32namespace tvm {
33namespace arith {
34
35using tir::IterVar;
36using tir::Var;
37using tir::VarNode;
38
39class Analyzer;
40
41//-----------------------------------------------
42// Integer set data structure.
43//
// This is an API built on top of the base
// integer analysis API to provide set analysis.
46//------------------------------------------------
/*!
 * \brief Sign type of an integer expression.
 *
 * The explicit enumerator values below are identical to the implicit
 * declaration-order values, so the numbering is unchanged.
 */
enum SignType {
  /*! \brief Positive sign. */
  kPositive = 0,
  /*! \brief Negative sign. */
  kNegative = 1,
  /*! \brief Zero. */
  kZero = 2,
  /*! \brief The sign is unknown. */
  kUnknown = 3
};
51
/*!
 * \brief Base class of all integer set containers.
 *  Represents a set of integers in one dimension.
 * \sa IntSet
 */
class IntSetNode : public Object {
 public:
  /*! \brief Type key identifying this node type in the object system. */
  static constexpr const char* _type_key = "IntSet";
  // NOTE(review): declares that this node type provides no SEqualReduce method;
  // presumably IntSet nodes therefore do not take part in structural equality — confirm.
  static constexpr bool _type_has_method_sequal_reduce = false;
  TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};
63
/*!
 * \brief Managed reference to IntSetNode.
 * \sa IntSetNode
 */
class IntSet : public ObjectRef {
 public:
  /*!
   * \brief Find a range that covers the region.
   * \param max_range The range to be covered.
   * \return The covering range.
   */
  Range CoverRange(Range max_range) const;
  /*! \return The lower bound of the set. */
  PrimExpr min() const;
  /*! \return The upper bound of the set. */
  PrimExpr max() const;
  /*! \return The sign of the elements in the integer set. */
  SignType GetSignType() const;
  /*! \return Whether the set represents nothing (the empty set). */
  bool IsNothing() const;
  /*! \return Whether the set represents everything. */
  bool IsEverything() const;
  /*! \return Whether the set is a single point. */
  bool IsSinglePoint() const;
  /*! \return Whether the set is proved to be greater than 0. */
  bool CanProvePositive() const;
  /*! \return Whether the set is proved to be less than 0. */
  bool CanProveNegative() const;
  /*! \return Whether the set is proved to be less than or equal to 0. */
  bool CanProveNonPositive() const;
  /*! \return Whether the set is proved to be greater than or equal to 0. */
  bool CanProveNonNegative() const;
  /*! \return Whether the set has an upper bound. */
  bool HasUpperBound() const;
  /*! \return Whether the set has a lower bound. */
  bool HasLowerBound() const;

  /*!
   * \brief The single point value; call only if IsSinglePoint() is true.
   * \return The point value.
   */
  PrimExpr PointValue() const;
  /*!
   * \brief Try to match the IntSet with range r.
   *
   * \param r The range to match against.
   * \note It is guaranteed that IntSet::FromRange(r).MatchRange(r) == true.
   * \return true if we can prove they are the same.
   */
  bool MatchRange(const tvm::Range& r) const;
  /*! \return The set that contains nothing. */
  static IntSet Nothing();
  /*! \return The set that contains everything. */
  static IntSet Everything();
  /*!
   * \brief Construct a single-point set.
   * \param point The point in the set.
   * \return The constructed single-point set.
   */
  static IntSet SinglePoint(PrimExpr point);
  /*!
   * \brief Construct an integer set from a vector expression.
   * \param vec The vector expression; can also be a single point.
   * \return The result set containing the indices in the vector.
   */
  static IntSet Vector(PrimExpr vec);
  /*!
   * \brief Construct a set representing the range [min, min + extent).
   * \param min The minimum of the range.
   * \param extent The extent of the range.
   * \return The constructed set.
   */
  static IntSet FromMinExtent(PrimExpr min, PrimExpr extent);
  /*!
   * \brief Construct a set representing a range.
   * \param r The range.
   * \return The constructed set.
   */
  static IntSet FromRange(tvm::Range r);
  /*!
   * \brief Construct a set representing an interval.
   * \param min The minimum value of the interval.
   * \param max The maximum value of the interval.
   * \return The constructed set.
   */
  static IntSet Interval(PrimExpr min, PrimExpr max);

  TVM_DEFINE_OBJECT_REF_METHODS(IntSet, ObjectRef, IntSetNode);
};
152
153//-----------------------------------------------
154// Integer set legacy API.
155//------------------------------------------------
/*!
 * \brief Convert a std::unordered_map<const VarNode*, IntSet> to a Map<Var, IntSet>.
 *
 * \param dom_map The domain map to convert.
 * \return The converted map.
 */
Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
 * \brief Find a symbolic integer set that contains all possible values of
 *  e given the domain of each iteration variable.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each iteration variable.
 * \return An integer set that can cover all the possible values of e.
 */
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
/*!
 * \brief Find a symbolic integer set that contains all possible values of
 *  e given the domain of each variable.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values of e.
 */
IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map);
/*!
 * \brief Same as EvalSet, but takes a std::unordered_map keyed by variable node.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values of e.
 */
IntSet EvalSet(PrimExpr e, const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
/*!
 * \brief Find a symbolic integer set that covers the union of range r
 *  over all possible values of the iteration variables in dom_map.
 *
 * \param r The initial range.
 * \param dom_map The domain of each iteration variable.
 * \return An integer set that can cover all the possible values.
 */
IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map);
198
/*!
 * \brief Find a symbolic integer set that covers the union of set s
 *  over all possible values of the variables in dom_map.
 *
 * \param s The initial set.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values.
 */
IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
 * \brief Same as EvalSet, but takes a range and a std::unordered_map.
 *
 * \param r The range to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values in r.
 */
IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
 * \brief Same as EvalSet, but takes an Array<Range>.
 *
 * \param region The array of ranges to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An array of integer sets that can cover all the possible values.
 */
Array<IntSet> EvalSet(const Array<Range>& region, const Map<Var, IntSet>& dom_map);
/*!
 * \brief Map from Expr to IntSet.
 *  Keys are hashed and compared with ObjectPtrHash/ObjectPtrEqual,
 *  i.e. by object pointer rather than by structural equality.
 */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual>;
/*!
 * \brief Find the integer set of every sub-expression, given the
 *  domain of each variable.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return The map from each sub-expression to its possible values.
 */
ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
                                    const std::unordered_map<const VarNode*, IntSet>& dom_map);
236
/*!
 * \brief Create a union set of all sets, possibly relaxed.
 * \param sets The sets to be combined.
 * \return The set after union.
 */
IntSet Union(const Array<IntSet>& sets);
243
/*!
 * \brief The union of N-dimensional integer sets.
 * \param nd_int_sets A list of N-dimensional integer sets.
 * \return An N-dimensional integer set as the result of the union.
 */
Array<IntSet> UnionRegion(const Array<Array<IntSet>>& nd_int_sets);
250
/*!
 * \brief Create a lower bound of the union set, where some of the segments may be dropped.
 * \param sets The sets to be combined.
 * \return The set after union.
 */
IntSet UnionLowerBound(const Array<IntSet>& sets);
257
/*!
 * \brief A lower bound of the union of N-dimensional integer sets.
 *  This is the N-dimensional counterpart of UnionLowerBound: some of the
 *  segments may be dropped from the result.
 * \param nd_int_sets A list of N-dimensional integer sets.
 * \return An N-dimensional integer set as the lower-bound result of the union.
 */
Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);
264
/*!
 * \brief Create the intersection of all sets.
 * \param sets The sets to be intersected.
 * \return The set after intersection.
 */
IntSet Intersect(const Array<IntSet>& sets);
271
/*!
 * \brief Convert the ranges of variables to integer sets.
 * \param var_dom The ranges of the variables.
 * \return The integer sets of the variables.
 */
Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom);
278
/*!
 * \brief Analyze the region with an affine map, given the domain of variables and their predicate.
 *  The result should be strict, i.e. no region is discarded or relaxed.
 * \param region The region to be analyzed.
 * \param var_dom The ranges of the variables.
 * \param predicate The predicate for the affine map.
 * \param analyzer The analyzer used.
 * \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis.
 */
TVM_DLL Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region,
                                                          const Map<Var, Range>& var_dom,
                                                          const PrimExpr& predicate,
                                                          arith::Analyzer* analyzer);
292
/*!
 * \brief Analyze the region with an affine map, given the domain of variables and their predicate.
 *  Some subregions may be discarded during the lower-bound analysis.
 * \param region The region to be analyzed.
 * \param var_dom The ranges of the variables.
 * \param predicate The predicate for the affine map.
 * \param analyzer The analyzer used.
 * \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis.
 */
TVM_DLL Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
                                                         const Map<Var, Range>& var_dom,
                                                         const PrimExpr& predicate,
                                                         arith::Analyzer* analyzer);
306
/*!
 * \brief Analyze the region with an affine map, given the domain of variables and their predicate.
 *  Relaxation of the region may be used in the upper-bound analysis, i.e. some extra region
 *  may be added to the result.
 * \param region The region to be analyzed.
 * \param var_dom The ranges of the variables.
 * \param predicate The predicate for the affine map.
 * \param analyzer The analyzer used.
 * \return An array of arith::IntSet as the result of analysis.
 */
TVM_DLL Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region,
                                               const Map<Var, Range>& var_dom,
                                               const PrimExpr& predicate,
                                               arith::Analyzer* analyzer);
321
322} // namespace arith
323} // namespace tvm
324#endif // TVM_ARITH_INT_SET_H_
325