1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/index_set.cc
22 * \brief Efficient representation of a set of post-dfs indexes.
23 */
24
#include "./index_set.h"

#include <utility>
26
27namespace tvm {
28namespace relay {
29namespace collage {
30
31// TODO(mbs): These should operate one-word-at-a-time
32
33IndexSet::IndexSet(size_t size, const std::vector<size_t>& indexes) : bitvec_(size, false) {
34 for (size_t index : indexes) {
35 ICHECK_LT(index, bitvec_.size());
36 ICHECK(!bitvec_[index]);
37 bitvec_[index] = true;
38 }
39}
40
41IndexSet IndexSet::operator&(const IndexSet& that) const {
42 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
43 std::vector<bool> result(bitvec_.size(), false);
44 for (size_t index = 0; index < bitvec_.size(); ++index) {
45 result[index] = bitvec_[index] && that.bitvec_[index];
46 }
47 return IndexSet(result);
48}
49
50IndexSet IndexSet::operator|(const IndexSet& that) const {
51 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
52 std::vector<bool> result(bitvec_.size(), false);
53 for (size_t index = 0; index < bitvec_.size(); ++index) {
54 result[index] = bitvec_[index] || that.bitvec_[index];
55 }
56 return IndexSet(result);
57}
58
59IndexSet IndexSet::operator-(const IndexSet& that) const {
60 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
61 std::vector<bool> result(bitvec_.size());
62 for (size_t index = 0; index < bitvec_.size(); ++index) {
63 result[index] = bitvec_[index] && !that.bitvec_[index];
64 }
65 return IndexSet(result);
66}
67
68bool IndexSet::AreDisjoint(const IndexSet& that) const {
69 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
70 for (size_t index = 0; index < bitvec_.size(); index++) {
71 if (bitvec_[index] && that.bitvec_[index]) {
72 return false;
73 }
74 }
75 return true;
76}
77
78bool IndexSet::IsSubset(const IndexSet& that) const {
79 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
80 for (size_t index = 0; index < bitvec_.size(); index++) {
81 if (bitvec_[index] && !that.bitvec_[index]) {
82 return false;
83 }
84 }
85 return true;
86}
87
88bool IndexSet::Intersects(const IndexSet& that) const {
89 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
90 for (size_t index = 0; index < bitvec_.size(); index++) {
91 if (bitvec_[index] && that.bitvec_[index]) {
92 return true;
93 }
94 }
95 return false;
96}
97
98IndexSet IndexSet::Subst(size_t new_size, const IndexSubst& subst) const {
99 std::vector<bool> result(new_size, false);
100 for (PostDfsIndex index = 0; index < bitvec_.size(); ++index) {
101 if (!bitvec_[index]) {
102 continue;
103 }
104 auto itr = subst.find(index);
105 ICHECK(itr != subst.end());
106 PostDfsIndex new_index = itr->second;
107 ICHECK(new_index < new_size);
108 ICHECK(!result[new_index]);
109 result[new_index] = true;
110 }
111 return IndexSet(result);
112}
113
114size_t IndexSet::PopCount() const {
115 size_t n = 0;
116 for (size_t index = 0; index < bitvec_.size(); index++) {
117 if (bitvec_[index]) {
118 ++n;
119 }
120 }
121 return n;
122}
123
124bool IndexSet::IsZero() const {
125 for (size_t index = 0; index < bitvec_.size(); index++) {
126 if (bitvec_[index]) {
127 return false;
128 }
129 }
130 return true;
131}
132
133size_t IndexSet::FirstInsideIndex() const {
134 for (size_t index = 0; index < bitvec_.size(); index++) {
135 if (bitvec_[index]) {
136 return index;
137 }
138 }
139 return bitvec_.size();
140}
141
142size_t IndexSet::LastInsideIndex() const {
143 for (size_t i = bitvec_.size(); i > 0; i--) {
144 const size_t index = i - 1;
145 if (bitvec_[index]) {
146 return index;
147 }
148 }
149 return bitvec_.size();
150}
151
152size_t IndexSet::NextIndex(size_t index) const {
153 ICHECK_LT(index, bitvec_.size());
154 for (index++; index < bitvec_.size(); index++) {
155 if (bitvec_[index]) {
156 return index;
157 }
158 }
159 return bitvec_.size();
160}
161
162size_t IndexSet::FirstOutsideIndex() const {
163 for (size_t index = 0; index < bitvec_.size(); index++) {
164 if (!bitvec_[index]) {
165 return index;
166 }
167 }
168 return bitvec_.size();
169}
170
171bool IndexSet::operator==(const IndexSet& that) const {
172 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
173 return bitvec_ == that.bitvec_;
174}
175
176bool IndexSet::operator!=(const IndexSet& that) const {
177 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
178 return bitvec_ != that.bitvec_;
179}
180
181bool IndexSet::operator<(const IndexSet& that) const {
182 ICHECK_EQ(bitvec_.size(), that.bitvec_.size());
183 for (size_t index = 0; index < bitvec_.size(); index++) {
184 if (bitvec_[index] && !that.bitvec_[index]) {
185 return true;
186 }
187 if (!bitvec_[index] && that.bitvec_[index]) {
188 return false;
189 }
190 }
191 return false;
192}
193
194size_t IndexSet::hash() const {
195 std::hash<std::vector<bool>> h;
196 return h(bitvec_);
197}
198
199std::string IndexSet::ToString() const {
200 std::ostringstream os;
201 os << "{";
202 bool first = true;
203 for (size_t start = 0; start < bitvec_.size(); /*no-op*/) {
204 if (!bitvec_[start]) {
205 ++start;
206 continue;
207 }
208 size_t end;
209 for (end = start + 1; end < bitvec_.size() && bitvec_[end]; ++end) {
210 /*no-op*/
211 }
212 if (first) {
213 first = false;
214 } else {
215 os << ",";
216 }
217 os << start;
218 if (end > start + 2) {
219 os << ".." << (end - 1);
220 start = end;
221 } else {
222 ++start;
223 }
224 }
225 os << "}";
226 return os.str();
227}
228
229} // namespace collage
230} // namespace relay
231} // namespace tvm
232