1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/collage_partitioner.cc
22 * \brief Search for an optimal partitioning of a Relay model.
23 */
24
25#include "./collage_partitioner.h"
26
27#include <math.h>
28#include <tvm/ir/attrs.h>
29#include <tvm/ir/function.h>
30#include <tvm/ir/transform.h>
31#include <tvm/relay/expr.h>
32#include <tvm/relay/function.h>
33#include <tvm/relay/transform.h>
34#include <tvm/target/target.h>
35
36#include "../ir/dataflow_matcher_impl.h"
37#include "../transforms/compiler_function_utils.h"
38#include "../transforms/device_aware_visitors.h"
39#include "./candidate_partition.h"
40#include "./candidate_partition_index.h"
41#include "./cost.h"
42#include "./cost_estimator.h"
43#include "./gather_partition_specs.h"
44#include "./name_supply.h"
45#include "./partition_rule.h"
46#include "./partition_spec.h"
47#include "./priority_queue.h"
48#include "./sub_graph.h"
49#include "./utils.h"
50
51namespace tvm {
52namespace relay {
53namespace collage {
54namespace {
55
// Register the Collage-specific PassContext configuration keys together with the type of value
// each key accepts.
// NOTE(review): these keys are not read anywhere in this file; presumably they are consumed by
// the partition-spec gathering / partition-rule code (depth limits for TVM and BYOC candidates,
// and per-toolchain fusion style) -- confirm against gather_partition_specs.cc.
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.tvm_max_depth", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_max_depth", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_fusion_style", Array<String>);
59/*!
60 * \brief Represents the overall expression after some number of non-overlapping candidate
61 * partitions have been applied.
62 */
63class SearchState {
64 public:
65 explicit SearchState(IndexSet covered) : covered_(std::move(covered)) {}
66
67 /*!
68 * \brief Order states by increasing best cost, breaking ties by lexicographic order on
69 * the covering sub graph.
70 */
71 bool operator<(const SearchState& that) const {
72 return std::tie(best_cost_, covered_) < std::tie(that.best_cost_, that.covered_);
73 }
74
75 const IndexSet& covered() const { return covered_; }
76
77 std::string ToString() const {
78 std::ostringstream os;
79 os << "State(";
80 os << "covered=" << covered_.ToString();
81 os << ",best_cost=" << best_cost_.ToString();
82 if (best_candidate_.defined()) {
83 os << ",best_candidate=" << best_candidate_->ToString();
84 }
85 os << ")";
86 return os.str();
87 }
88
89 private:
90 /*! \brief Which nodes of overall expression have been placed on all paths to this state. */
91 IndexSet covered_;
92 /*! \brief Predecessor state for sequence of candidates reaching this state with least
93 * cost. Null if initial search state. */
94 SearchState* pred_state_ = nullptr;
95 /*!
96 * \brief Cost of reaching this state using placement implied by path given by pred_state fields.
97 * Includes estimated/measured cost of all candidates plus any candidate launch penalty.
98 * Initially invalid cost.
99 */
100 Cost best_cost_ = Cost::Invalid();
101 /*! \brief Candidate partition selected in transition from pred_state to this state. */
102 CandidatePartition best_candidate_;
103
104 friend class Partitioner;
105};
106
107struct CompareSearchStatePtrs {
108 bool operator()(const SearchState* left, const SearchState* right) const {
109 return *left < *right;
110 }
111};
112
113struct EqualSearchStatePtrs {
114 bool operator()(const SearchState* left, const SearchState* right) const {
115 return left->covered() == right->covered();
116 }
117};
118
119/*!
120 * \brief Finds the optimal partitioning of an expression to candidate partitions.
121 * Though no candidate partitions overlap, it is possible some sub-expressions end up in
122 * no candidate. Those sub-expressions must be evaluated by the host executor (eg VM).
123 */
124class Partitioner {
125 public:
126 explicit Partitioner(Array<PartitionSpec> partition_specs,
127 const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices,
128 CostEstimator cost_estimator, std::shared_ptr<CandidateFunctionCache> cache,
129 Expr expr)
130 : partition_specs_(std::move(partition_specs)),
131 virtual_devices_(virtual_devices),
132 cost_estimator_(std::move(cost_estimator)),
133 cache_(std::move(cache)),
134 expr_(std::move(expr)) {}
135
136 Expr Partition() {
137 // Establish core data structures.
138 dataflow_graph_ = std::make_unique<DataflowGraph>(expr_);
139 VLOG(1) << "Created dataflow graph with " << dataflow_graph_->size() << " nodes";
140
141 // Build the candidate index. This is where all the partition rules are invoked .
142 index_ = std::make_unique<CandidatePartitionIndex>(virtual_devices_, dataflow_graph_.get());
143 index_->Index(partition_specs_);
144 VLOG(1) << "All candidates before search:" << std::endl << index_->ToSummary();
145
146 // 'Eagerly' estimate the cost of all candidates.
147 //
148 // Note if this is not done costs will simply be estimated 'lazily' as the search proceeds.
149 // Typically, some candidates are never explored during the search because:
150 // - There are no paths in which the candidate does not intersect candidates already
151 // applied on the path.
152 // - The Dijkstra search terminates early with a least cost path.
153 // So eager may result in more estimation overhead. However, eager could be made
154 // embarrassingly parallel.
155 VLOG(1) << "Beginning eager cost estimation";
156 index_->EstimateAllCosts(cost_estimator_, cache_);
157 VLOG(1) << "Finished eager cost estimation";
158
159 // Setup initial state.
160 SearchState* init_state = GetState(IndexSet(dataflow_graph_->size()));
161 init_state->best_cost_ = Cost::Zero();
162 pq_.Push(init_state);
163
164 size_t num_transitions = 0;
165
166 VLOG(1) << "#### Commencing Collage search over " << index_->size() << " candidates ####";
167 while (!pq_.empty()) {
168 SearchState* curr_state = pq_.Pop();
169 VLOG(1) << "Looking at state " << curr_state->covered_.ToString();
170 PostDfsIndex next_index = curr_state->covered_.FirstOutsideIndex();
171
172 if (next_index >= dataflow_graph_->size()) {
173 // The entire expression has been explored. Collect the candidates on the optimal path.
174 VLOG(1) << "#### Finished Collage search after exploring " << num_transitions
175 << " transitions ####";
176 std::vector<CandidatePartition> best_candidates;
177 while (curr_state != init_state) {
178 ICHECK(curr_state->best_candidate_.defined());
179 best_candidates.emplace_back(curr_state->best_candidate_);
180 curr_state = curr_state->pred_state_;
181 ICHECK(curr_state != nullptr);
182 }
183 return Finalize(best_candidates);
184 }
185
186 size_t num_fires = 0;
187 Expr sub_expr = dataflow_graph_->index_to_node(next_index)->ref();
188 VLOG(1) << "Looking at index " << next_index << " for sub-expression "
189 << SubExprKindAndLabel(sub_expr).second << " out of " << dataflow_graph_->size()
190 << " total dataflow nodes";
191
192 // Explore all the outgoing candidates from the current state.
193 for (const auto& candidate : index_->candidates_at(next_index)) {
194 VLOG(1) << "Considering candidate " << candidate->ToSummary(*dataflow_graph_)
195 << " for transition " << ++num_transitions << " over " << index_->size()
196 << " total candidates";
197 if (!candidate->sub_graph_->inside_.AreDisjoint(curr_state->covered_)) {
198 LOG(INFO) << "Candidate overlaps with already partitioned nodes";
199 continue;
200 }
201 IndexSet next_covered = curr_state->covered_ | candidate->sub_graph_->inside_;
202 SearchState* next_state = GetState(next_covered);
203 Relax(curr_state, next_state, candidate);
204 ++num_fires;
205 }
206 ICHECK_GT(num_fires, 0)
207 << "No candidate was found covering sub-expression at index " << next_index
208 << ", suggesting the partition rules are incomplete for the given targets.";
209 }
210
211 ICHECK(false) << "should have reached end state in which all sub-expressions are covered";
212 return {};
213 }
214
215 /*! \brief Returns the unique state corresponding to the \p covered sub-graph. */
216 SearchState* GetState(const IndexSet& covered) {
217 auto itr = covered_to_state_.find(covered);
218 if (itr != covered_to_state_.end()) {
219 return itr->second.get();
220 }
221 auto state = std::make_unique<SearchState>(covered);
222 SearchState* raw_ptr = state.get();
223 covered_to_state_.emplace(covered, std::move(state));
224 return raw_ptr;
225 }
226
227 /*!
228 * \brief Record that it is possible to reach \p next_state by choosing \p candidate
229 * in \p curr_state. If the resulting cost is better than the best known so far, update
230 * \p next_state's best cost, predecessor and candidate to match.
231 */
232 void Relax(SearchState* curr_state, SearchState* next_state,
233 const CandidatePartition& candidate) {
234 // Note this may already be cached if the candidate partition costs were 'eagerly' estimated.
235 Cost candidate_cost = candidate->EstimatedCost(*dataflow_graph_, cost_estimator_, cache_);
236 VLOG(1) << "Candidate has cost " << candidate_cost.ToString();
237 Cost new_state_cost = candidate_cost + curr_state->best_cost_;
238 const bool is_new = next_state->best_cost_.is_invalid();
239 CandidatePartition previously_best_candidate = next_state->best_candidate_;
240 if (is_new || new_state_cost < next_state->best_cost_) {
241 next_state->pred_state_ = curr_state;
242 Cost previously_best_cost = next_state->best_cost_;
243 next_state->best_cost_ = new_state_cost;
244 next_state->best_candidate_ = candidate;
245 if (is_new) {
246 VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
247 << " (New state for spec " << candidate->partition_spec_name() << ")";
248 pq_.Push(next_state);
249 } else {
250 VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
251 << " (Spec " << candidate->partition_spec_name() << " beats previous spec "
252 << previously_best_candidate->partition_spec_name() << " by "
253 << (previously_best_cost - curr_state->best_cost_).ToString() << ")";
254 pq_.Update(next_state);
255 }
256 } else {
257 VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString()
258 << " (Spec " << candidate->partition_spec_name() << " does not beat existing spec "
259 << previously_best_candidate->partition_spec_name() << ")";
260 }
261 }
262
263 /*!
264 * \brief Returns the result of partitioning \p expr according to 'optimal' candidates found
265 * by the search.
266 */
267 Expr Finalize(std::vector<CandidatePartition> best_candidates) {
268 best_candidates = CandidatePartition::MaxCoalesce(*dataflow_graph_, best_candidates);
269
270 Cost total_cost = Cost::Zero();
271 std::ostringstream os;
272 os << "Optimal partitioning:" << std::endl;
273 for (const auto& best_candidate : best_candidates) {
274 if (best_candidate->partition_spec_name() == kHostSpecName) {
275 continue;
276 }
277 os << best_candidate->ToSummary(*dataflow_graph_);
278 os << std::endl;
279 total_cost = total_cost + best_candidate->cost_;
280 }
281 os << "Estimated overall cost is " << total_cost.ToString();
282 LOG(INFO) << os.str();
283
284 LOG(INFO) << "All candidates after search:" << std::endl << index_->ToSummary();
285
286 return CandidatePartition::ParallelRewrite(*dataflow_graph_, best_candidates);
287 }
288
289 private:
290 /*! \brief Available partition specs to use during search. */
291 Array<PartitionSpec> partition_specs_;
292 /*!
293 * \brief The virtual devices for every sub-expression so we can respect any existing target
294 * constraints.
295 */
296 const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices_;
297 /*! \brief Cost estimator to use for candidates. */
298 CostEstimator cost_estimator_;
299 /*! \brief Cached names and costs for all partition functions. */
300 std::shared_ptr<CandidateFunctionCache> cache_;
301 /*! \brief The expression we will be partitioning. */
302 Expr expr_;
303 /*! \brief Dataflow graph for overall expression. */
304 std::unique_ptr<DataflowGraph> dataflow_graph_;
305 /*! \brief Index of all avoilable candidates we are searching over. */
306 std::unique_ptr<CandidatePartitionIndex> index_;
307 /*! \brief Map from covered sub-graphs to the corresponding state. */
308 std::unordered_map<IndexSet, std::unique_ptr<SearchState>, IndexSetHash, IndexSetEqual>
309 covered_to_state_;
310 /*! \brief Priority queue of states, ordered by increasing cost. */
311 PriorityQueue<SearchState, CompareSearchStatePtrs, EqualSearchStatePtrs> pq_;
312};
313
314} // namespace
315
316transform::Pass CollagePartition(CompilationConfig config, CostEstimator cost_estimator) {
317 runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
318 [config = std::move(config), cost_estimator = std::move(cost_estimator)](
319 IRModule mod, transform::PassContext ctxt) {
320 VLOG(1) << "CollagePartition input:" << std::endl << PrettyPrint(mod);
321
322 Array<PartitionSpec> partition_specs = GatherPartitionSpecs(config);
323 VLOG(1) << "Gathered " << partition_specs.size() << " partition specs";
324
325 auto cache =
326 std::make_shared<CandidateFunctionCache>(std::make_shared<NameSupply>("collage"));
327
328 IRModule out_mod = mod->ShallowCopy();
329 for (const auto& kv : mod->functions) {
330 if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
331 auto function = GetRef<Function>(function_node);
332 std::unordered_map<const ExprNode*, VirtualDevice> virtual_devices =
333 transform::RecoverVirtualDeviceMap(mod, function);
334 Partitioner partitioner(partition_specs, &virtual_devices, cost_estimator, cache,
335 function);
336 Function result = Downcast<Function>(partitioner.Partition());
337 out_mod->Add(kv.first, result);
338 }
339 }
340
341 out_mod = OutlineCompilerFunctions(cache)(std::move(out_mod));
342 VLOG(1) << "CollagePartition result:" << std::endl << PrettyPrint(out_mod);
343 return out_mod;
344 };
345 return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "CollagePartition", {});
346}
347
// Register CollagePartition as a global function under the name
// "relay._transform.CollagePartition".
// NOTE(review): presumably this is the entry point used by the Python-side relay.transform
// wrapper -- confirm against the Python bindings.
TVM_REGISTER_GLOBAL("relay._transform.CollagePartition").set_body_typed(CollagePartition);
349
350} // namespace collage
351} // namespace relay
352} // namespace tvm
353