1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/index_set.cc |
22 | * \brief Efficient representation of a set of post-dfs indexes. |
23 | */ |
24 | |
25 | #include "./index_set.h" |
26 | |
27 | namespace tvm { |
28 | namespace relay { |
29 | namespace collage { |
30 | |
31 | // TODO(mbs): These should operate one-word-at-a-time |
32 | |
33 | IndexSet::IndexSet(size_t size, const std::vector<size_t>& indexes) : bitvec_(size, false) { |
34 | for (size_t index : indexes) { |
35 | ICHECK_LT(index, bitvec_.size()); |
36 | ICHECK(!bitvec_[index]); |
37 | bitvec_[index] = true; |
38 | } |
39 | } |
40 | |
41 | IndexSet IndexSet::operator&(const IndexSet& that) const { |
42 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
43 | std::vector<bool> result(bitvec_.size(), false); |
44 | for (size_t index = 0; index < bitvec_.size(); ++index) { |
45 | result[index] = bitvec_[index] && that.bitvec_[index]; |
46 | } |
47 | return IndexSet(result); |
48 | } |
49 | |
50 | IndexSet IndexSet::operator|(const IndexSet& that) const { |
51 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
52 | std::vector<bool> result(bitvec_.size(), false); |
53 | for (size_t index = 0; index < bitvec_.size(); ++index) { |
54 | result[index] = bitvec_[index] || that.bitvec_[index]; |
55 | } |
56 | return IndexSet(result); |
57 | } |
58 | |
59 | IndexSet IndexSet::operator-(const IndexSet& that) const { |
60 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
61 | std::vector<bool> result(bitvec_.size()); |
62 | for (size_t index = 0; index < bitvec_.size(); ++index) { |
63 | result[index] = bitvec_[index] && !that.bitvec_[index]; |
64 | } |
65 | return IndexSet(result); |
66 | } |
67 | |
68 | bool IndexSet::AreDisjoint(const IndexSet& that) const { |
69 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
70 | for (size_t index = 0; index < bitvec_.size(); index++) { |
71 | if (bitvec_[index] && that.bitvec_[index]) { |
72 | return false; |
73 | } |
74 | } |
75 | return true; |
76 | } |
77 | |
78 | bool IndexSet::IsSubset(const IndexSet& that) const { |
79 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
80 | for (size_t index = 0; index < bitvec_.size(); index++) { |
81 | if (bitvec_[index] && !that.bitvec_[index]) { |
82 | return false; |
83 | } |
84 | } |
85 | return true; |
86 | } |
87 | |
88 | bool IndexSet::Intersects(const IndexSet& that) const { |
89 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
90 | for (size_t index = 0; index < bitvec_.size(); index++) { |
91 | if (bitvec_[index] && that.bitvec_[index]) { |
92 | return true; |
93 | } |
94 | } |
95 | return false; |
96 | } |
97 | |
98 | IndexSet IndexSet::Subst(size_t new_size, const IndexSubst& subst) const { |
99 | std::vector<bool> result(new_size, false); |
100 | for (PostDfsIndex index = 0; index < bitvec_.size(); ++index) { |
101 | if (!bitvec_[index]) { |
102 | continue; |
103 | } |
104 | auto itr = subst.find(index); |
105 | ICHECK(itr != subst.end()); |
106 | PostDfsIndex new_index = itr->second; |
107 | ICHECK(new_index < new_size); |
108 | ICHECK(!result[new_index]); |
109 | result[new_index] = true; |
110 | } |
111 | return IndexSet(result); |
112 | } |
113 | |
114 | size_t IndexSet::PopCount() const { |
115 | size_t n = 0; |
116 | for (size_t index = 0; index < bitvec_.size(); index++) { |
117 | if (bitvec_[index]) { |
118 | ++n; |
119 | } |
120 | } |
121 | return n; |
122 | } |
123 | |
124 | bool IndexSet::IsZero() const { |
125 | for (size_t index = 0; index < bitvec_.size(); index++) { |
126 | if (bitvec_[index]) { |
127 | return false; |
128 | } |
129 | } |
130 | return true; |
131 | } |
132 | |
133 | size_t IndexSet::FirstInsideIndex() const { |
134 | for (size_t index = 0; index < bitvec_.size(); index++) { |
135 | if (bitvec_[index]) { |
136 | return index; |
137 | } |
138 | } |
139 | return bitvec_.size(); |
140 | } |
141 | |
142 | size_t IndexSet::LastInsideIndex() const { |
143 | for (size_t i = bitvec_.size(); i > 0; i--) { |
144 | const size_t index = i - 1; |
145 | if (bitvec_[index]) { |
146 | return index; |
147 | } |
148 | } |
149 | return bitvec_.size(); |
150 | } |
151 | |
152 | size_t IndexSet::NextIndex(size_t index) const { |
153 | ICHECK_LT(index, bitvec_.size()); |
154 | for (index++; index < bitvec_.size(); index++) { |
155 | if (bitvec_[index]) { |
156 | return index; |
157 | } |
158 | } |
159 | return bitvec_.size(); |
160 | } |
161 | |
162 | size_t IndexSet::FirstOutsideIndex() const { |
163 | for (size_t index = 0; index < bitvec_.size(); index++) { |
164 | if (!bitvec_[index]) { |
165 | return index; |
166 | } |
167 | } |
168 | return bitvec_.size(); |
169 | } |
170 | |
171 | bool IndexSet::operator==(const IndexSet& that) const { |
172 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
173 | return bitvec_ == that.bitvec_; |
174 | } |
175 | |
176 | bool IndexSet::operator!=(const IndexSet& that) const { |
177 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
178 | return bitvec_ != that.bitvec_; |
179 | } |
180 | |
181 | bool IndexSet::operator<(const IndexSet& that) const { |
182 | ICHECK_EQ(bitvec_.size(), that.bitvec_.size()); |
183 | for (size_t index = 0; index < bitvec_.size(); index++) { |
184 | if (bitvec_[index] && !that.bitvec_[index]) { |
185 | return true; |
186 | } |
187 | if (!bitvec_[index] && that.bitvec_[index]) { |
188 | return false; |
189 | } |
190 | } |
191 | return false; |
192 | } |
193 | |
194 | size_t IndexSet::hash() const { |
195 | std::hash<std::vector<bool>> h; |
196 | return h(bitvec_); |
197 | } |
198 | |
199 | std::string IndexSet::ToString() const { |
200 | std::ostringstream os; |
201 | os << "{" ; |
202 | bool first = true; |
203 | for (size_t start = 0; start < bitvec_.size(); /*no-op*/) { |
204 | if (!bitvec_[start]) { |
205 | ++start; |
206 | continue; |
207 | } |
208 | size_t end; |
209 | for (end = start + 1; end < bitvec_.size() && bitvec_[end]; ++end) { |
210 | /*no-op*/ |
211 | } |
212 | if (first) { |
213 | first = false; |
214 | } else { |
215 | os << "," ; |
216 | } |
217 | os << start; |
218 | if (end > start + 2) { |
219 | os << ".." << (end - 1); |
220 | start = end; |
221 | } else { |
222 | ++start; |
223 | } |
224 | } |
225 | os << "}" ; |
226 | return os.str(); |
227 | } |
228 | |
229 | } // namespace collage |
230 | } // namespace relay |
231 | } // namespace tvm |
232 | |