1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/index_set.h
22 * \brief Efficient representation of a set of post-dfs indexes.
23 */
24
25#ifndef TVM_RELAY_COLLAGE_INDEX_SET_H_
26#define TVM_RELAY_COLLAGE_INDEX_SET_H_
27
28#include <string>
29#include <unordered_map>
30#include <utility>
31#include <vector>
32
33#include "../ir/dataflow_matcher_impl.h"
34#include "../ir/indexed_graph.h"
35
36namespace tvm {
37namespace relay {
38namespace collage {
39
/*!
 * \brief Mapping from an index to a replacement index, consumed by IndexSet::Subst.
 * NOTE(review): presumably keys are indexes in the source set and values are indexes in the
 * resized result set -- confirm against IndexSet::Subst's implementation.
 */
using IndexSubst = std::unordered_map<std::size_t, std::size_t>;
41
42class IndexSet {
43 public:
44 IndexSet() = default;
45 explicit IndexSet(size_t size) : bitvec_(size, false) {}
46 IndexSet(size_t size, const std::vector<size_t>& indexes);
47
48 IndexSet operator&(const IndexSet& that) const;
49 IndexSet operator|(const IndexSet& that) const;
50 IndexSet operator-(const IndexSet& that) const;
51 bool AreDisjoint(const IndexSet& that) const;
52 bool IsSubset(const IndexSet& that) const;
53 bool Intersects(const IndexSet& that) const;
54
55 bool operator[](size_t index) const {
56 ICHECK_LT(index, bitvec_.size());
57 return bitvec_[index];
58 }
59
60 IndexSet& Add(size_t index) {
61 ICHECK_LT(index, bitvec_.size());
62 bitvec_[index] = true;
63 return *this;
64 }
65
66 IndexSet Subst(size_t new_size, const IndexSubst& subst) const;
67
68 size_t end_index() const { return bitvec_.size(); }
69 size_t PopCount() const;
70 bool IsZero() const;
71 size_t FirstInsideIndex() const;
72 size_t LastInsideIndex() const;
73 size_t NextIndex(size_t index) const;
74 size_t FirstOutsideIndex() const;
75 bool operator==(const IndexSet& that) const;
76 bool operator!=(const IndexSet& that) const;
77 bool operator<(const IndexSet& that) const;
78 size_t hash() const;
79 std::string ToString() const;
80
81 struct IndexSetIterator {
82 const IndexSet* set;
83 size_t i;
84
85 size_t operator*() const {
86 ICHECK_LT(i, set->end_index());
87 return i;
88 }
89
90 const IndexSetIterator& operator++() {
91 ICHECK_LT(i, set->end_index());
92 i = set->NextIndex(i);
93 return *this;
94 }
95
96 bool operator==(const IndexSetIterator& that) const {
97 ICHECK(set == that.set);
98 return i == that.i;
99 }
100
101 bool operator!=(const IndexSetIterator& that) const {
102 ICHECK(set == that.set);
103 return i != that.i;
104 }
105 };
106
107 IndexSetIterator begin() const { return IndexSetIterator{this, FirstInsideIndex()}; }
108 IndexSetIterator end() const { return IndexSetIterator{this, end_index()}; }
109
110 private:
111 explicit IndexSet(std::vector<bool> bitvec) : bitvec_(std::move(bitvec)) {}
112
113 std::vector<bool> bitvec_;
114};
115
116struct IndexSetEqual {
117 bool operator()(const IndexSet& left, const IndexSet& right) const { return left == right; }
118};
119
120struct IndexSetHash {
121 size_t operator()(const IndexSet& set) const { return set.hash(); }
122};
123
124} // namespace collage
125} // namespace relay
126} // namespace tvm
127
128#endif // TVM_RELAY_COLLAGE_INDEX_SET_H_
129