1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/candidate_set.cc
22 * \brief Collects a set of candidate partitions.
23 */
24
#include "./candidate_set.h"

#include <algorithm>
#include <unordered_set>
26
27namespace tvm {
28namespace relay {
29namespace collage {
30
31CandidateSet::CandidateSet(std::vector<CandidatePartition> candidates_to_add)
32 : candidates_to_add_(std::move(candidates_to_add)) {
33 for (const auto& candidate : candidates_to_add_) {
34 seen_.emplace(candidate);
35 }
36}
37
38void CandidateSet::Add(const DataflowGraph& dataflow_graph,
39 const CandidatePartition& new_candidate) {
40 VLOG(2) << "adding " << new_candidate->ToString();
41 if (seen_.count(new_candidate)) {
42 VLOG(2) << "already seen candidate, ignoring";
43 return;
44 }
45 seen_.emplace(new_candidate);
46 candidates_to_add_.emplace_back(new_candidate);
47}
48
49void CandidateSet::Remove(const CandidatePartition& old_candidate) {
50 ICHECK(seen_.count(old_candidate));
51 VLOG(2) << "removing " << old_candidate->ToString();
52 candidates_to_remove_.emplace_back(old_candidate);
53}
54
55bool CandidateSet::PrepareForNextRound() {
56 size_t init_size = current_candidates_.size();
57 for (const auto& candidate_to_remove : candidates_to_remove_) {
58 current_candidates_.erase(
59 std::remove(current_candidates_.begin(), current_candidates_.end(), candidate_to_remove),
60 current_candidates_.end());
61 }
62 size_t num_removed = init_size - current_candidates_.size();
63 candidates_to_remove_.clear();
64 first_new_index_ = current_candidates_.size();
65 for (const auto& new_candidate : candidates_to_add_) {
66 current_candidates_.push_back(new_candidate);
67 }
68 size_t num_added = candidates_to_add_.size();
69 candidates_to_add_.clear();
70 VLOG(1) << "removed " << num_removed << " and added " << num_added << " candidates";
71 return num_removed + num_added > 0;
72}
73
74} // namespace collage
75} // namespace relay
76} // namespace tvm
77