1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/index_set.h |
22 | * \brief Efficient representation of a set of post-dfs indexes. |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_COLLAGE_INDEX_SET_H_ |
26 | #define TVM_RELAY_COLLAGE_INDEX_SET_H_ |
27 | |
28 | #include <string> |
29 | #include <unordered_map> |
30 | #include <utility> |
31 | #include <vector> |
32 | |
33 | #include "../ir/dataflow_matcher_impl.h" |
34 | #include "../ir/indexed_graph.h" |
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | namespace collage { |
39 | |
40 | using IndexSubst = std::unordered_map<size_t, size_t>; |
41 | |
42 | class IndexSet { |
43 | public: |
44 | IndexSet() = default; |
45 | explicit IndexSet(size_t size) : bitvec_(size, false) {} |
46 | IndexSet(size_t size, const std::vector<size_t>& indexes); |
47 | |
48 | IndexSet operator&(const IndexSet& that) const; |
49 | IndexSet operator|(const IndexSet& that) const; |
50 | IndexSet operator-(const IndexSet& that) const; |
51 | bool AreDisjoint(const IndexSet& that) const; |
52 | bool IsSubset(const IndexSet& that) const; |
53 | bool Intersects(const IndexSet& that) const; |
54 | |
55 | bool operator[](size_t index) const { |
56 | ICHECK_LT(index, bitvec_.size()); |
57 | return bitvec_[index]; |
58 | } |
59 | |
60 | IndexSet& Add(size_t index) { |
61 | ICHECK_LT(index, bitvec_.size()); |
62 | bitvec_[index] = true; |
63 | return *this; |
64 | } |
65 | |
66 | IndexSet Subst(size_t new_size, const IndexSubst& subst) const; |
67 | |
68 | size_t end_index() const { return bitvec_.size(); } |
69 | size_t PopCount() const; |
70 | bool IsZero() const; |
71 | size_t FirstInsideIndex() const; |
72 | size_t LastInsideIndex() const; |
73 | size_t NextIndex(size_t index) const; |
74 | size_t FirstOutsideIndex() const; |
75 | bool operator==(const IndexSet& that) const; |
76 | bool operator!=(const IndexSet& that) const; |
77 | bool operator<(const IndexSet& that) const; |
78 | size_t hash() const; |
79 | std::string ToString() const; |
80 | |
81 | struct IndexSetIterator { |
82 | const IndexSet* set; |
83 | size_t i; |
84 | |
85 | size_t operator*() const { |
86 | ICHECK_LT(i, set->end_index()); |
87 | return i; |
88 | } |
89 | |
90 | const IndexSetIterator& operator++() { |
91 | ICHECK_LT(i, set->end_index()); |
92 | i = set->NextIndex(i); |
93 | return *this; |
94 | } |
95 | |
96 | bool operator==(const IndexSetIterator& that) const { |
97 | ICHECK(set == that.set); |
98 | return i == that.i; |
99 | } |
100 | |
101 | bool operator!=(const IndexSetIterator& that) const { |
102 | ICHECK(set == that.set); |
103 | return i != that.i; |
104 | } |
105 | }; |
106 | |
107 | IndexSetIterator begin() const { return IndexSetIterator{this, FirstInsideIndex()}; } |
108 | IndexSetIterator end() const { return IndexSetIterator{this, end_index()}; } |
109 | |
110 | private: |
111 | explicit IndexSet(std::vector<bool> bitvec) : bitvec_(std::move(bitvec)) {} |
112 | |
113 | std::vector<bool> bitvec_; |
114 | }; |
115 | |
116 | struct IndexSetEqual { |
117 | bool operator()(const IndexSet& left, const IndexSet& right) const { return left == right; } |
118 | }; |
119 | |
120 | struct IndexSetHash { |
121 | size_t operator()(const IndexSet& set) const { return set.hash(); } |
122 | }; |
123 | |
124 | } // namespace collage |
125 | } // namespace relay |
126 | } // namespace tvm |
127 | |
128 | #endif // TVM_RELAY_COLLAGE_INDEX_SET_H_ |
129 | |