1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/prune_candidates.cc
22 * \brief Try to remove candidates which will never contribute to an optimal partitioning.
23 */
24
25#include "./prune_candidates.h"
26
27#include "./dataflow_graph.h"
28#include "./gather_partition_specs.h"
29
30namespace tvm {
31namespace relay {
32namespace collage {
33
34namespace {
35
36/*!
37 * \brief Returns a map from post-dfs dataflow node indices to the indices within \p candidates for
38 * those candidates which intersect that dataflow node.
39 *
40 * NOTE: The index set in the vector results is over candidate indices not post-dfs indices!
41 */
42std::vector<IndexSet> MakeInsideMap(const DataflowGraph& dataflow_graph,
43 const std::vector<CandidatePartition>& candidates) {
44 std::vector<IndexSet> result(dataflow_graph.size(), IndexSet(candidates.size()));
45 for (size_t i = 0; i < candidates.size(); ++i) {
46 CandidatePartition candidate = candidates[i];
47 for (PostDfsIndex index : candidate->sub_graph_->inside_) {
48 result[index].Add(i);
49 }
50 }
51 return result;
52}
53
54/*!
55 * \brief Returns the maximal candidates within \p candidates. A candidate is maximal if it is not
56 * contained by any super-candidate for the same target.
57 */
58std::vector<CandidatePartition> MaximalCandidates(
59 const DataflowGraph& dataflow_graph, const std::vector<CandidatePartition>& candidates) {
60 std::vector<IndexSet> inside_map = MakeInsideMap(dataflow_graph, candidates);
61 std::vector<CandidatePartition> result;
62 for (size_t i = 0; i < candidates.size(); ++i) {
63 CandidatePartition maximal_candidate = candidates[i];
64 bool has_super_candidate = false;
65 IndexSet explored_candidates(candidates.size()); // over candidates!
66 for (PostDfsIndex index : maximal_candidate->sub_graph_->inside_) {
67 for (size_t j : inside_map[index]) {
68 if (i == j) {
69 // Ignore self.
70 continue;
71 }
72 if (explored_candidates[j]) {
73 // Already checked.
74 continue;
75 }
76 explored_candidates.Add(j);
77 CandidatePartition super_candidate = candidates[j];
78 if (maximal_candidate->spec_ == super_candidate->spec_ &&
79 maximal_candidate->sub_graph_->inside_.IsSubset(super_candidate->sub_graph_->inside_)) {
80 has_super_candidate = true;
81 break;
82 }
83 }
84 if (has_super_candidate) {
85 break;
86 }
87 }
88 if (!has_super_candidate) {
89 VLOG(2) << "Found maximal candidate " << maximal_candidate->ToString();
90 result.emplace_back(maximal_candidate);
91 }
92 }
93 VLOG(1) << "Have " << result.size() << " maximal candidates";
94 return result;
95}
96
97/*!
98 * \brief Returns all the candidates in \p candidates which intersect without being equal.
99 */
100std::vector<CandidatePartition> IntersectingCandidates(
101 const DataflowGraph& dataflow_graph, const std::vector<CandidatePartition>& candidates) {
102 std::vector<IndexSet> inside_map = MakeInsideMap(dataflow_graph, candidates);
103 IndexSet intersecting(candidates.size()); // over candidates!
104 for (size_t i = 0; i < candidates.size(); ++i) {
105 CandidatePartition intersecting_candidate = candidates[i];
106 IndexSet explored_candidates(candidates.size()); // over candidates!
107 for (PostDfsIndex index : intersecting_candidate->sub_graph_->inside_) {
108 for (size_t j : inside_map[index]) {
109 if (j < i) {
110 // Intersection is commutative.
111 continue;
112 }
113 if (i == j) {
114 // Ignore self.
115 continue;
116 }
117 if (explored_candidates[j]) {
118 // Already checked.
119 continue;
120 }
121 explored_candidates.Add(j);
122 CandidatePartition other_candidate = candidates[j];
123 if (intersecting_candidate->sub_graph_->inside_ == other_candidate->sub_graph_->inside_) {
124 // Have same inside set.
125 continue;
126 }
127 VLOG(2) << "Candidate " << intersecting_candidate->ToString() << " intersects with "
128 << other_candidate->ToString();
129 intersecting.Add(i);
130 intersecting.Add(j);
131 }
132 }
133 }
134 std::vector<CandidatePartition> result;
135 for (size_t i : intersecting) {
136 CandidatePartition candidate = candidates[i];
137 VLOG(2) << "Found intersecting candidate " << candidate->ToString();
138 result.emplace_back(candidate);
139 }
140 VLOG(1) << "Have " << result.size() << " intersecting candidates";
141 return result;
142}
143
144/*!
145 * \brief Returns the set operation left - right.
146 */
147std::vector<CandidatePartition> SetDifference(const std::vector<CandidatePartition>& left,
148 const std::vector<CandidatePartition>& right) {
149 std::unordered_set<CandidatePartition, CandidatePartitionHash, CandidatePartitionEquals>
150 right_set(right.begin(), right.end());
151 std::vector<CandidatePartition> result;
152 for (const auto& candidate : left) {
153 if (right_set.count(candidate) == 0) {
154 result.emplace_back(candidate);
155 }
156 }
157 return result;
158}
159
160/*!
161 * \brief Adds everything in right to left. Returns the number of elements added.
162 */
163size_t SetUnionInPlace(
164 std::unordered_set<CandidatePartition, CandidatePartitionHash, CandidatePartitionEquals>* left,
165 const std::vector<CandidatePartition>& right) {
166 size_t init_size = left->size();
167 for (const auto& candidate : right) {
168 left->emplace(candidate);
169 }
170 return left->size() - init_size;
171}
172
173} // namespace
174
175std::vector<CandidatePartition> PruneCandidates(
176 const DataflowGraph& dataflow_graph,
177 const std::vector<CandidatePartition>& initial_candidates) {
178 VLOG_CONTEXT << "prune";
179 // Start with all candidates available.
180 std::vector<CandidatePartition> candidates = initial_candidates;
181 std::unordered_set<CandidatePartition, CandidatePartitionHash, CandidatePartitionEquals> pruned;
182 size_t initial_num_candidates = candidates.size();
183 size_t num_rounds = 0;
184 while (true) {
185 VLOG_CONTEXT << "round " << ++num_rounds;
186 VLOG(1) << "checking " << candidates.size() << " candidates";
187 // Add all the maximal candidates to the pruned set.
188 std::vector<CandidatePartition> maximal_candidates =
189 MaximalCandidates(dataflow_graph, candidates);
190 size_t num_new_pruned = SetUnionInPlace(&pruned, maximal_candidates);
191 VLOG(1) << "Added " << num_new_pruned << " new pruned candidates";
192 if (num_new_pruned == 0) {
193 // We've reached a fixed point.
194 break;
195 }
196 // If two pruned candidates intersect without being equal then we may miss valid
197 // paths during search. So remove those intersecting candidates from the available candidates
198 // and try again so as to find smaller candidates to 'bridge the gaps'.
199 std::vector<CandidatePartition> pruned_vec(pruned.begin(), pruned.end());
200 std::vector<CandidatePartition> intersecting_candidates =
201 IntersectingCandidates(dataflow_graph, pruned_vec);
202 // We need more maximal candidates to fill in the gaps between the current pruned candidates.
203 // Force that by removing the intersecting candidates from the set of available candidates
204 // and going around again.
205 candidates = SetDifference(candidates, intersecting_candidates);
206 }
207
208 std::vector<CandidatePartition> result(pruned.begin(), pruned.end());
209 // Re-establish a canonical order of candidates.
210 std::sort(result.begin(), result.end());
211 VLOG(1) << "Pruned " << initial_num_candidates - result.size() << " candidates (ie from "
212 << initial_num_candidates << " to " << result.size() << ")";
213 return result;
214}
215
216} // namespace collage
217} // namespace relay
218} // namespace tvm
219