1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/int_set.h |
22 | * \brief Integer set |
23 | */ |
24 | #ifndef TVM_ARITH_INT_SET_H_ |
25 | #define TVM_ARITH_INT_SET_H_ |
26 | |
27 | #include <tvm/ir/expr.h> |
28 | #include <tvm/tir/expr.h> |
29 | |
30 | #include <unordered_map> |
31 | |
32 | namespace tvm { |
33 | namespace arith { |
34 | |
35 | using tir::IterVar; |
36 | using tir::Var; |
37 | using tir::VarNode; |
38 | |
39 | class Analyzer; |
40 | |
41 | //----------------------------------------------- |
42 | // Integer set data structure. |
43 | // |
// This is an API built on top of the base
// integer analysis API to provide set analysis.
46 | //------------------------------------------------ |
/*!
 * \brief Sign classification of an integer expression.
 *
 * The enumerator values are spelled out explicitly; they are exactly the
 * implicit values (0, 1, 2, 3) that the declaration order assigns.
 */
enum SignType {
  /*! \brief Sign: positive. */
  kPositive = 0,
  /*! \brief Sign: negative. */
  kNegative = 1,
  /*! \brief Sign: zero. */
  kZero = 2,
  /*! \brief Presumably: the sign could not be determined. */
  kUnknown = 3
};
51 | |
/*!
 * \brief Base class of all integer set containers.
 *
 * Each container represents a set of integers in one dimension.
 * \sa IntSet
 */
class IntSetNode : public Object {
 public:
  /*! \brief Type key of this node; presumably consumed by TVM_DECLARE_BASE_OBJECT_INFO — confirm. */
  static constexpr const char* _type_key = "IntSet";
  /*!
   * \brief Object-system flag, set to false here.
   * NOTE(review): presumably means this node type provides no SEqualReduce
   * method for structural equality — confirm against the object system.
   */
  static constexpr bool _type_has_method_sequal_reduce = false;
  TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};
63 | |
/*!
 * \brief Managed reference to IntSetNode.
 * \sa IntSetNode
 */
class IntSet : public ObjectRef {
 public:
  /*!
   * \brief Find a range that covers the region.
   * \param max_range The range to be covered.
   * \return The covering range.
   */
  Range CoverRange(Range max_range) const;
  /*! \return Lower bound of the set. */
  PrimExpr min() const;
  /*! \return Upper bound of the set. */
  PrimExpr max() const;
  /*! \return The sign of the elements in the integer set. */
  SignType GetSignType() const;
  /*! \return Whether the set represents nothing. */
  bool IsNothing() const;
  /*! \return Whether the set represents everything. */
  bool IsEverything() const;
  /*! \return Whether the set is a single point. */
  bool IsSinglePoint() const;
  /*! \return Whether the set is proved to be greater than 0. */
  bool CanProvePositive() const;
  /*! \return Whether the set is proved to be less than 0. */
  bool CanProveNegative() const;
  /*! \return Whether the set is proved to be less than or equal to 0. */
  bool CanProveNonPositive() const;
  /*! \return Whether the set is proved to be greater than or equal to 0. */
  bool CanProveNonNegative() const;
  /*! \return Whether the set has an upper bound. */
  bool HasUpperBound() const;
  /*! \return Whether the set has a lower bound. */
  bool HasLowerBound() const;

  /*!
   * \brief The single point value; call only if IsSinglePoint() is true.
   * \return The point value.
   */
  PrimExpr PointValue() const;
  /*!
   * \brief Try to match the IntSet with range r.
   *
   * \note It is guaranteed that IntSet::FromRange(r).MatchRange(r) == true.
   * \param r The range to match against.
   * \return true if we can prove they are the same.
   */
  bool MatchRange(const tvm::Range& r) const;
  /*! \return The set that contains nothing. */
  static IntSet Nothing();
  /*! \return The set that contains everything. */
  static IntSet Everything();
  /*!
   * \brief Construct a single-point set.
   * \param point The point in the set.
   * \return The constructed single-point set.
   */
  static IntSet SinglePoint(PrimExpr point);
  /*!
   * \brief Construct an integer set from a vector expression.
   * \param vec The vector expression; it can also be a single point.
   * \return The result set containing the indices in the vector.
   */
  static IntSet Vector(PrimExpr vec);
  /*!
   * \brief Construct a set representing the range [min, min + extent).
   * \param min The minimum of the range.
   * \param extent The extent of the range.
   * \return The constructed set.
   */
  static IntSet FromMinExtent(PrimExpr min, PrimExpr extent);
  /*!
   * \brief Construct a set representing a range.
   * \param r The range.
   * \return The constructed set.
   */
  static IntSet FromRange(tvm::Range r);
  /*!
   * \brief Construct a set representing an interval.
   * \param min The minimum value of the interval.
   * \param max The maximum value of the interval.
   * \return The constructed set.
   * \note NOTE(review): presumably both endpoints are inclusive, i.e.
   *  [min, max] (unlike FromMinExtent's half-open form) — confirm against
   *  the implementation.
   */
  static IntSet Interval(PrimExpr min, PrimExpr max);

  TVM_DEFINE_OBJECT_REF_METHODS(IntSet, ObjectRef, IntSetNode);
};
152 | |
153 | //----------------------------------------------- |
154 | // Integer set legacy API. |
155 | //------------------------------------------------ |
/*!
 * \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>.
 *
 * \param dom_map The domain map to convert, keyed by variable node pointers.
 * \return The converted map, keyed by Var.
 */
Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
 * \brief Find a symbolic integer set that contains all possible values of
 *  e, given the domain of each iteration variable.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each iteration variable.
 * \return An integer set that can cover all the possible values of e.
 */
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
/*!
 * \brief Find a symbolic integer set that contains all possible values of
 *  e, given the domain of each variable.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values of e.
 */
IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map);
/*!
 * \brief Counterpart of EvalSet(PrimExpr, const Map<Var, IntSet>&) that
 *  takes the domains as an unordered_map keyed by variable node pointers.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values of e.
 */
IntSet EvalSet(PrimExpr e, const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
/*!
 * \brief Find a symbolic integer set that covers the union of the values of
 *  range r over all the possible variable values given in dom_map.
 *
 * \param r The initial range.
 * \param dom_map The domain of each iteration variable.
 * \return An integer set that can cover all the possible values.
 */
IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map);
198 | |
/*!
 * \brief Find a symbolic integer set that covers the union of the values of
 *  set s over all the possible variable values given in dom_map.
 *
 * \param s The initial set.
 * \param dom_map The domain of each variable, keyed by variable node pointers.
 * \return An integer set that can cover all the possible values.
 */
IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
 * \brief Same as EvalSet(Range, const Map<IterVar, IntSet>&), but takes the
 *  domains as an unordered_map keyed by variable node pointers.
 *
 * \param r The range to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An integer set that can cover all the possible values of r.
 */
IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
 * \brief Same as EvalSet, but evaluates an Array<Range> region.
 *
 * \param region The ranges to be evaluated.
 * \param dom_map The domain of each variable.
 * \return An array of integer sets that can cover all the possible values
 *  (presumably one set per entry of region — confirm).
 */
Array<IntSet> EvalSet(const Array<Range>& region, const Map<Var, IntSet>& dom_map);
/*!
 * \brief Map from Expr to IntSet.
 *
 * Keys use ObjectPtrHash / ObjectPtrEqual; presumably this means lookup is by
 * object identity rather than structural equality — confirm.
 */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual>;
/*!
 * \brief Find the integer set of every sub-expression, given the
 *  domain of each iteration variable.
 *
 * \param e The expression to be evaluated.
 * \param dom_map The domain of each variable, keyed by variable node pointers.
 * \return The map from each sub-expression to its possible values.
 */
ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
                                    const std::unordered_map<const VarNode*, IntSet>& dom_map);
236 | |
/*!
 * \brief Create a union set of all sets, possibly relaxed.
 *
 * "Relaxed" means the result may cover more than the exact union.
 * \param sets The sets to be combined.
 * \return The set after union.
 */
IntSet Union(const Array<IntSet>& sets);
243 | |
/*!
 * \brief The union of N-dimensional integer sets.
 * \param nd_int_sets A list of N-dimensional integer sets, each given as an
 *  array of per-dimension IntSets.
 * \return An N-dimensional integer set as the result of the union.
 */
Array<IntSet> UnionRegion(const Array<Array<IntSet>>& nd_int_sets);
250 | |
/*!
 * \brief Create a lower bound of the union of the sets; some of the
 *  segments may be dropped, so the result may cover less than the exact union.
 * \param sets The sets to be combined.
 * \return The set after union.
 */
IntSet UnionLowerBound(const Array<IntSet>& sets);
257 | |
/*!
 * \brief A lower bound of the union of N-dimensional integer sets.
 *
 * Like UnionRegion, but some subregions may be dropped (see UnionLowerBound),
 * so the result may cover less than the exact union.
 * \param nd_int_sets A list of N-dimensional integer sets, each given as an
 *  array of per-dimension IntSets.
 * \return An N-dimensional integer set as the lower-bound result of the union.
 */
Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);
264 | |
/*!
 * \brief Create the intersection of all sets.
 * \param sets The sets to be intersected.
 * \return The intersection of the sets.
 */
IntSet Intersect(const Array<IntSet>& sets);
271 | |
/*!
 * \brief Convert the ranges of variables to integer sets.
 * \param var_dom The ranges of the variables.
 * \return The integer sets of the variables.
 */
Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom);
278 | |
/*!
 * \brief Analyze the region with an affine map, given the domain of the variables and
 *  their predicate.
 *
 * The result should be strict, i.e. no region is discarded or relaxed.
 * \param region The region to be analyzed.
 * \param var_dom The ranges of the variables.
 * \param predicate The predicate for the affine map.
 * \param analyzer The analyzer used.
 * \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis.
 */
TVM_DLL Optional<Array<IntSet>> EstimateRegionStrictBound(const Array<Range>& region,
                                                          const Map<Var, Range>& var_dom,
                                                          const PrimExpr& predicate,
                                                          arith::Analyzer* analyzer);
292 | |
/*!
 * \brief Analyze the region with an affine map, given the domain of the variables and
 *  their predicate.
 *
 * Some subregions may be discarded during the lower-bound analysis.
 * \param region The region to be analyzed.
 * \param var_dom The ranges of the variables.
 * \param predicate The predicate for the affine map.
 * \param analyzer The analyzer used.
 * \return NullOpt if the detection fails, or an array of arith::IntSet as the result of analysis.
 */
TVM_DLL Optional<Array<IntSet>> EstimateRegionLowerBound(const Array<Range>& region,
                                                         const Map<Var, Range>& var_dom,
                                                         const PrimExpr& predicate,
                                                         arith::Analyzer* analyzer);
306 | |
/*!
 * \brief Analyze the region with an affine map, given the domain of the variables and
 *  their predicate.
 *
 * The upper-bound analysis may relax the region, i.e. some extra region may be
 * added to the result.
 * \param region The region to be analyzed.
 * \param var_dom The ranges of the variables.
 * \param predicate The predicate for the affine map.
 * \param analyzer The analyzer used.
 * \return An array of arith::IntSet as the result of analysis.
 */
TVM_DLL Array<IntSet> EstimateRegionUpperBound(const Array<Range>& region,
                                               const Map<Var, Range>& var_dom,
                                               const PrimExpr& predicate,
                                               arith::Analyzer* analyzer);
321 | |
322 | } // namespace arith |
323 | } // namespace tvm |
324 | #endif // TVM_ARITH_INT_SET_H_ |
325 | |