1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_SUPPORT_ND_INT_SET_H_
20#define TVM_SUPPORT_ND_INT_SET_H_
21
22#include <tvm/arith/int_set.h>
23#include <tvm/ir/expr.h>
24
25#include <unordered_map>
26#include <vector>
27
28namespace tvm {
29namespace support {
30
31/*! \brief An N-dimensional integer set representing a rectangle region */
32using NDIntSet = std::vector<arith::IntSet>;
33
34/*!
35 * \brief Construct an N-dimensional integer set representing a region.
36 * \param region The region.
37 * \return The constructed set.
38 */
39inline NDIntSet NDIntSetFromRegion(const tir::Region& region) {
40 NDIntSet result;
41 result.reserve(region.size());
42 for (const Range& range : region) {
43 result.push_back(arith::IntSet::FromRange(range));
44 }
45 return result;
46}
47
48/*!
49 * \brief Construct an N-dimensional integer set representing a shape.
50 * \param shape The shape which is an array of the length of each dimension.
51 * \return The constructed set.
52 */
53inline NDIntSet NDIntSetFromShape(const Array<PrimExpr>& shape) {
54 PrimExpr zero = Integer(0);
55 NDIntSet result;
56 result.reserve(shape.size());
57 for (const PrimExpr& extent : shape) {
58 result.push_back(arith::IntSet::FromMinExtent(zero, extent));
59 }
60 return result;
61}
62
63/*!
64 * \brief Construct an N-dimensional integer set representing a point.
65 * \param indices The N-dimensional indices representing the point.
66 * \return The constructed set.
67 */
68inline NDIntSet NDIntSetFromPoint(const Array<PrimExpr>& indices) {
69 NDIntSet result;
70 result.reserve(indices.size());
71 for (const PrimExpr& index : indices) {
72 result.push_back(arith::IntSet::SinglePoint(index));
73 }
74 return result;
75}
76
77/*!
78 * \brief Create a union set of two sets, possibly relaxed. The RHS set will be combined into the
79 * LHS set.
80 * \param lhs The first N-dimensional integer set
81 * \param rhs The second N-dimensional integer set
82 */
83inline void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) {
84 ICHECK_EQ(lhs->size(), rhs.size());
85 int ndim = rhs.size();
86 for (int i = 0; i < ndim; ++i) {
87 arith::IntSet& int_set = lhs->at(i);
88 int_set = arith::Union({int_set, rhs.at(i)});
89 }
90}
91
92/*!
93 * \brief Union a list of N-dimensional integer sets
94 * \param nd_int_sets The N-dimensional integer sets to be merged.
95 * \return The result of the union
96 */
97inline NDIntSet NDIntSetUnion(const std::vector<NDIntSet>& nd_int_sets) {
98 ICHECK(!nd_int_sets.empty());
99 int n = nd_int_sets.size();
100 if (n == 1) {
101 return nd_int_sets[0];
102 }
103 int ndim = nd_int_sets[0].size();
104 for (int i = 1; i < n; ++i) {
105 ICHECK_EQ(nd_int_sets[i].size(), ndim);
106 }
107 NDIntSet result;
108 result.reserve(ndim);
109 Array<arith::IntSet> int_sets(n, arith::IntSet{nullptr});
110 for (int dim = 0; dim < ndim; ++dim) {
111 for (int i = 0; i < n; ++i) {
112 int_sets.Set(i, nd_int_sets[i][dim]);
113 }
114 result.push_back(arith::Union(int_sets));
115 }
116 return result;
117}
118
119/*!
120 * \brief Create an empty N-dimensional integer set.
121 * \param ndim The number of dimensions.
122 * \return The constructed set.
123 */
124inline NDIntSet NDIntSetEmpty(int ndim) {
125 return std::vector<arith::IntSet>(ndim, arith::IntSet::Nothing());
126}
127
128/*!
129 * \brief The N-dimensional version of EvalSet.
130 * \param nd_int_set The N-dimensional integer set to be evaluated.
131 * \param dom_map The domain of each variable.
132 * \return An N-dimensional integer set that can cover all the possible values of the N-dimensional
133 * integer set.
134 * \sa EvalSet
135 */
136inline NDIntSet NDIntSetEval(
137 const NDIntSet& nd_int_set,
138 const std::unordered_map<const tir::VarNode*, arith::IntSet>& dom_map) {
139 NDIntSet ret;
140 ret.reserve(nd_int_set.size());
141 for (const arith::IntSet& s : nd_int_set) {
142 ret.push_back(EvalSet(s, dom_map));
143 }
144 return ret;
145}
146
147} // namespace support
148} // namespace tvm
149
150#endif // TVM_SUPPORT_ND_INT_SET_H_
151