1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/candidate_set.h
22 * \brief Collects a set of candidate partitions.
23 */
24
25#ifndef TVM_RELAY_COLLAGE_CANDIDATE_SET_H_
26#define TVM_RELAY_COLLAGE_CANDIDATE_SET_H_
27
28#include <algorithm>
29#include <unordered_set>
30#include <utility>
31#include <vector>
32
33#include "./candidate_partition.h"
34#include "./dataflow_graph.h"
35
36namespace tvm {
37namespace relay {
38namespace collage {
39
40/*!
41 * \brief Holds a vector of current candidates and the additions/removals to apply to them.
42 */
43struct CandidateSet {
44 CandidateSet() = default;
45
46 explicit CandidateSet(std::vector<CandidatePartition> candidates_to_add);
47
48 /*!
49 * \brief Schedule \p new_candidate for addition before the next round (unless it is not valid).
50 */
51 void Add(const DataflowGraph& dataflow_graph, const CandidatePartition& new_candidate);
52
53 /*! \brief Schedule \p old_candidate for removal before the next round. */
54 void Remove(const CandidatePartition& old_candidate);
55
56 /*!
57 * \brief Update \p current_candidates and \p first_new_index. Return false if no
58 * new candidates were added, in which case we have reached a fixed point.
59 */
60 bool PrepareForNextRound();
61
62 size_t size() const { return current_candidates_.size(); }
63
64 CandidatePartition operator[](size_t i) const {
65 ICHECK_LT(i, current_candidates_.size());
66 return current_candidates_[i];
67 }
68 CandidatePartition at(size_t i) const { return (*this)[i]; }
69
70 size_t first_new_index() const { return first_new_index_; }
71
72 void sort() { std::sort(current_candidates_.begin(), current_candidates_.end()); }
73
74 std::vector<CandidatePartition> MovedCurrentCandidates() {
75 return std::move(current_candidates_);
76 }
77
78 private:
79 /*!
80 * \brief Index of first candidate in current_candidates added in last round. This can be used to
81 * avoid considering candidates or candidate combinations which have already been considered in an
82 * earlier round.
83 */
84 size_t first_new_index_ = 0;
85 /*! \brief Candidates gathered in previous rounds. */
86 std::vector<CandidatePartition> current_candidates_;
87 /*! \brief New candidates gathered in the current round. */
88 std::vector<CandidatePartition> candidates_to_add_;
89 /*! \brief Existing candidates to remove before starting the next round. */
90 std::vector<CandidatePartition> candidates_to_remove_;
91 /*! \brief Which candidates have been seen so far and should not be added again. */
92 std::unordered_set<CandidatePartition, CandidatePartitionHash, CandidatePartitionEquals> seen_;
93};
94
95} // namespace collage
96} // namespace relay
97} // namespace tvm
98
99#endif // TVM_RELAY_COLLAGE_CANDIDATE_SET_H_
100