1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/collage_partitioner.cc |
22 | * \brief Search for an optimal partitioning of a Relay model. |
23 | */ |
24 | |
25 | #include "./collage_partitioner.h" |
26 | |
27 | #include <math.h> |
28 | #include <tvm/ir/attrs.h> |
29 | #include <tvm/ir/function.h> |
30 | #include <tvm/ir/transform.h> |
31 | #include <tvm/relay/expr.h> |
32 | #include <tvm/relay/function.h> |
33 | #include <tvm/relay/transform.h> |
34 | #include <tvm/target/target.h> |
35 | |
36 | #include "../ir/dataflow_matcher_impl.h" |
37 | #include "../transforms/compiler_function_utils.h" |
38 | #include "../transforms/device_aware_visitors.h" |
39 | #include "./candidate_partition.h" |
40 | #include "./candidate_partition_index.h" |
41 | #include "./cost.h" |
42 | #include "./cost_estimator.h" |
43 | #include "./gather_partition_specs.h" |
44 | #include "./name_supply.h" |
45 | #include "./partition_rule.h" |
46 | #include "./partition_spec.h" |
47 | #include "./priority_queue.h" |
48 | #include "./sub_graph.h" |
49 | #include "./utils.h" |
50 | |
51 | namespace tvm { |
52 | namespace relay { |
53 | namespace collage { |
54 | namespace { |
55 | |
// Register the Collage-specific pass configuration keys together with the type of value each
// key carries: two Integer-valued depth options and one Array<String>-valued fusion style
// option.
// NOTE(review): the code which reads these options is not in this file -- presumably the
// partition spec gathering / partition rules consume them; confirm before changing a key.
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.tvm_max_depth" , Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_max_depth" , Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("relay.collage.byoc_fusion_style" , Array<String>);
59 | /*! |
60 | * \brief Represents the overall expression after some number of non-overlapping candidate |
61 | * partitions have been applied. |
62 | */ |
63 | class SearchState { |
64 | public: |
65 | explicit SearchState(IndexSet covered) : covered_(std::move(covered)) {} |
66 | |
67 | /*! |
68 | * \brief Order states by increasing best cost, breaking ties by lexicographic order on |
69 | * the covering sub graph. |
70 | */ |
71 | bool operator<(const SearchState& that) const { |
72 | return std::tie(best_cost_, covered_) < std::tie(that.best_cost_, that.covered_); |
73 | } |
74 | |
75 | const IndexSet& covered() const { return covered_; } |
76 | |
77 | std::string ToString() const { |
78 | std::ostringstream os; |
79 | os << "State(" ; |
80 | os << "covered=" << covered_.ToString(); |
81 | os << ",best_cost=" << best_cost_.ToString(); |
82 | if (best_candidate_.defined()) { |
83 | os << ",best_candidate=" << best_candidate_->ToString(); |
84 | } |
85 | os << ")" ; |
86 | return os.str(); |
87 | } |
88 | |
89 | private: |
90 | /*! \brief Which nodes of overall expression have been placed on all paths to this state. */ |
91 | IndexSet covered_; |
92 | /*! \brief Predecessor state for sequence of candidates reaching this state with least |
93 | * cost. Null if initial search state. */ |
94 | SearchState* pred_state_ = nullptr; |
95 | /*! |
96 | * \brief Cost of reaching this state using placement implied by path given by pred_state fields. |
97 | * Includes estimated/measured cost of all candidates plus any candidate launch penalty. |
98 | * Initially invalid cost. |
99 | */ |
100 | Cost best_cost_ = Cost::Invalid(); |
101 | /*! \brief Candidate partition selected in transition from pred_state to this state. */ |
102 | CandidatePartition best_candidate_; |
103 | |
104 | friend class Partitioner; |
105 | }; |
106 | |
107 | struct CompareSearchStatePtrs { |
108 | bool operator()(const SearchState* left, const SearchState* right) const { |
109 | return *left < *right; |
110 | } |
111 | }; |
112 | |
113 | struct EqualSearchStatePtrs { |
114 | bool operator()(const SearchState* left, const SearchState* right) const { |
115 | return left->covered() == right->covered(); |
116 | } |
117 | }; |
118 | |
119 | /*! |
120 | * \brief Finds the optimal partitioning of an expression to candidate partitions. |
121 | * Though no candidate partitions overlap, it is possible some sub-expressions end up in |
122 | * no candidate. Those sub-expressions must be evaluated by the host executor (eg VM). |
123 | */ |
124 | class Partitioner { |
125 | public: |
126 | explicit Partitioner(Array<PartitionSpec> partition_specs, |
127 | const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices, |
128 | CostEstimator cost_estimator, std::shared_ptr<CandidateFunctionCache> cache, |
129 | Expr expr) |
130 | : partition_specs_(std::move(partition_specs)), |
131 | virtual_devices_(virtual_devices), |
132 | cost_estimator_(std::move(cost_estimator)), |
133 | cache_(std::move(cache)), |
134 | expr_(std::move(expr)) {} |
135 | |
136 | Expr Partition() { |
137 | // Establish core data structures. |
138 | dataflow_graph_ = std::make_unique<DataflowGraph>(expr_); |
139 | VLOG(1) << "Created dataflow graph with " << dataflow_graph_->size() << " nodes" ; |
140 | |
141 | // Build the candidate index. This is where all the partition rules are invoked . |
142 | index_ = std::make_unique<CandidatePartitionIndex>(virtual_devices_, dataflow_graph_.get()); |
143 | index_->Index(partition_specs_); |
144 | VLOG(1) << "All candidates before search:" << std::endl << index_->ToSummary(); |
145 | |
146 | // 'Eagerly' estimate the cost of all candidates. |
147 | // |
148 | // Note if this is not done costs will simply be estimated 'lazily' as the search proceeds. |
149 | // Typically, some candidates are never explored during the search because: |
150 | // - There are no paths in which the candidate does not intersect candidates already |
151 | // applied on the path. |
152 | // - The Dijkstra search terminates early with a least cost path. |
153 | // So eager may result in more estimation overhead. However, eager could be made |
154 | // embarrassingly parallel. |
155 | VLOG(1) << "Beginning eager cost estimation" ; |
156 | index_->EstimateAllCosts(cost_estimator_, cache_); |
157 | VLOG(1) << "Finished eager cost estimation" ; |
158 | |
159 | // Setup initial state. |
160 | SearchState* init_state = GetState(IndexSet(dataflow_graph_->size())); |
161 | init_state->best_cost_ = Cost::Zero(); |
162 | pq_.Push(init_state); |
163 | |
164 | size_t num_transitions = 0; |
165 | |
166 | VLOG(1) << "#### Commencing Collage search over " << index_->size() << " candidates ####" ; |
167 | while (!pq_.empty()) { |
168 | SearchState* curr_state = pq_.Pop(); |
169 | VLOG(1) << "Looking at state " << curr_state->covered_.ToString(); |
170 | PostDfsIndex next_index = curr_state->covered_.FirstOutsideIndex(); |
171 | |
172 | if (next_index >= dataflow_graph_->size()) { |
173 | // The entire expression has been explored. Collect the candidates on the optimal path. |
174 | VLOG(1) << "#### Finished Collage search after exploring " << num_transitions |
175 | << " transitions ####" ; |
176 | std::vector<CandidatePartition> best_candidates; |
177 | while (curr_state != init_state) { |
178 | ICHECK(curr_state->best_candidate_.defined()); |
179 | best_candidates.emplace_back(curr_state->best_candidate_); |
180 | curr_state = curr_state->pred_state_; |
181 | ICHECK(curr_state != nullptr); |
182 | } |
183 | return Finalize(best_candidates); |
184 | } |
185 | |
186 | size_t num_fires = 0; |
187 | Expr sub_expr = dataflow_graph_->index_to_node(next_index)->ref(); |
188 | VLOG(1) << "Looking at index " << next_index << " for sub-expression " |
189 | << SubExprKindAndLabel(sub_expr).second << " out of " << dataflow_graph_->size() |
190 | << " total dataflow nodes" ; |
191 | |
192 | // Explore all the outgoing candidates from the current state. |
193 | for (const auto& candidate : index_->candidates_at(next_index)) { |
194 | VLOG(1) << "Considering candidate " << candidate->ToSummary(*dataflow_graph_) |
195 | << " for transition " << ++num_transitions << " over " << index_->size() |
196 | << " total candidates" ; |
197 | if (!candidate->sub_graph_->inside_.AreDisjoint(curr_state->covered_)) { |
198 | LOG(INFO) << "Candidate overlaps with already partitioned nodes" ; |
199 | continue; |
200 | } |
201 | IndexSet next_covered = curr_state->covered_ | candidate->sub_graph_->inside_; |
202 | SearchState* next_state = GetState(next_covered); |
203 | Relax(curr_state, next_state, candidate); |
204 | ++num_fires; |
205 | } |
206 | ICHECK_GT(num_fires, 0) |
207 | << "No candidate was found covering sub-expression at index " << next_index |
208 | << ", suggesting the partition rules are incomplete for the given targets." ; |
209 | } |
210 | |
211 | ICHECK(false) << "should have reached end state in which all sub-expressions are covered" ; |
212 | return {}; |
213 | } |
214 | |
215 | /*! \brief Returns the unique state corresponding to the \p covered sub-graph. */ |
216 | SearchState* GetState(const IndexSet& covered) { |
217 | auto itr = covered_to_state_.find(covered); |
218 | if (itr != covered_to_state_.end()) { |
219 | return itr->second.get(); |
220 | } |
221 | auto state = std::make_unique<SearchState>(covered); |
222 | SearchState* raw_ptr = state.get(); |
223 | covered_to_state_.emplace(covered, std::move(state)); |
224 | return raw_ptr; |
225 | } |
226 | |
227 | /*! |
228 | * \brief Record that it is possible to reach \p next_state by choosing \p candidate |
229 | * in \p curr_state. If the resulting cost is better than the best known so far, update |
230 | * \p next_state's best cost, predecessor and candidate to match. |
231 | */ |
232 | void Relax(SearchState* curr_state, SearchState* next_state, |
233 | const CandidatePartition& candidate) { |
234 | // Note this may already be cached if the candidate partition costs were 'eagerly' estimated. |
235 | Cost candidate_cost = candidate->EstimatedCost(*dataflow_graph_, cost_estimator_, cache_); |
236 | VLOG(1) << "Candidate has cost " << candidate_cost.ToString(); |
237 | Cost new_state_cost = candidate_cost + curr_state->best_cost_; |
238 | const bool is_new = next_state->best_cost_.is_invalid(); |
239 | CandidatePartition previously_best_candidate = next_state->best_candidate_; |
240 | if (is_new || new_state_cost < next_state->best_cost_) { |
241 | next_state->pred_state_ = curr_state; |
242 | Cost previously_best_cost = next_state->best_cost_; |
243 | next_state->best_cost_ = new_state_cost; |
244 | next_state->best_candidate_ = candidate; |
245 | if (is_new) { |
246 | VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString() |
247 | << " (New state for spec " << candidate->partition_spec_name() << ")" ; |
248 | pq_.Push(next_state); |
249 | } else { |
250 | VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString() |
251 | << " (Spec " << candidate->partition_spec_name() << " beats previous spec " |
252 | << previously_best_candidate->partition_spec_name() << " by " |
253 | << (previously_best_cost - curr_state->best_cost_).ToString() << ")" ; |
254 | pq_.Update(next_state); |
255 | } |
256 | } else { |
257 | VLOG(1) << "transition " << curr_state->ToString() << " --> " << next_state->ToString() |
258 | << " (Spec " << candidate->partition_spec_name() << " does not beat existing spec " |
259 | << previously_best_candidate->partition_spec_name() << ")" ; |
260 | } |
261 | } |
262 | |
263 | /*! |
264 | * \brief Returns the result of partitioning \p expr according to 'optimal' candidates found |
265 | * by the search. |
266 | */ |
267 | Expr Finalize(std::vector<CandidatePartition> best_candidates) { |
268 | best_candidates = CandidatePartition::MaxCoalesce(*dataflow_graph_, best_candidates); |
269 | |
270 | Cost total_cost = Cost::Zero(); |
271 | std::ostringstream os; |
272 | os << "Optimal partitioning:" << std::endl; |
273 | for (const auto& best_candidate : best_candidates) { |
274 | if (best_candidate->partition_spec_name() == kHostSpecName) { |
275 | continue; |
276 | } |
277 | os << best_candidate->ToSummary(*dataflow_graph_); |
278 | os << std::endl; |
279 | total_cost = total_cost + best_candidate->cost_; |
280 | } |
281 | os << "Estimated overall cost is " << total_cost.ToString(); |
282 | LOG(INFO) << os.str(); |
283 | |
284 | LOG(INFO) << "All candidates after search:" << std::endl << index_->ToSummary(); |
285 | |
286 | return CandidatePartition::ParallelRewrite(*dataflow_graph_, best_candidates); |
287 | } |
288 | |
289 | private: |
290 | /*! \brief Available partition specs to use during search. */ |
291 | Array<PartitionSpec> partition_specs_; |
292 | /*! |
293 | * \brief The virtual devices for every sub-expression so we can respect any existing target |
294 | * constraints. |
295 | */ |
296 | const std::unordered_map<const ExprNode*, VirtualDevice>* virtual_devices_; |
297 | /*! \brief Cost estimator to use for candidates. */ |
298 | CostEstimator cost_estimator_; |
299 | /*! \brief Cached names and costs for all partition functions. */ |
300 | std::shared_ptr<CandidateFunctionCache> cache_; |
301 | /*! \brief The expression we will be partitioning. */ |
302 | Expr expr_; |
303 | /*! \brief Dataflow graph for overall expression. */ |
304 | std::unique_ptr<DataflowGraph> dataflow_graph_; |
305 | /*! \brief Index of all avoilable candidates we are searching over. */ |
306 | std::unique_ptr<CandidatePartitionIndex> index_; |
307 | /*! \brief Map from covered sub-graphs to the corresponding state. */ |
308 | std::unordered_map<IndexSet, std::unique_ptr<SearchState>, IndexSetHash, IndexSetEqual> |
309 | covered_to_state_; |
310 | /*! \brief Priority queue of states, ordered by increasing cost. */ |
311 | PriorityQueue<SearchState, CompareSearchStatePtrs, EqualSearchStatePtrs> pq_; |
312 | }; |
313 | |
314 | } // namespace |
315 | |
316 | transform::Pass CollagePartition(CompilationConfig config, CostEstimator cost_estimator) { |
317 | runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func = |
318 | [config = std::move(config), cost_estimator = std::move(cost_estimator)]( |
319 | IRModule mod, transform::PassContext ctxt) { |
320 | VLOG(1) << "CollagePartition input:" << std::endl << PrettyPrint(mod); |
321 | |
322 | Array<PartitionSpec> partition_specs = GatherPartitionSpecs(config); |
323 | VLOG(1) << "Gathered " << partition_specs.size() << " partition specs" ; |
324 | |
325 | auto cache = |
326 | std::make_shared<CandidateFunctionCache>(std::make_shared<NameSupply>("collage" )); |
327 | |
328 | IRModule out_mod = mod->ShallowCopy(); |
329 | for (const auto& kv : mod->functions) { |
330 | if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
331 | auto function = GetRef<Function>(function_node); |
332 | std::unordered_map<const ExprNode*, VirtualDevice> virtual_devices = |
333 | transform::RecoverVirtualDeviceMap(mod, function); |
334 | Partitioner partitioner(partition_specs, &virtual_devices, cost_estimator, cache, |
335 | function); |
336 | Function result = Downcast<Function>(partitioner.Partition()); |
337 | out_mod->Add(kv.first, result); |
338 | } |
339 | } |
340 | |
341 | out_mod = OutlineCompilerFunctions(cache)(std::move(out_mod)); |
342 | VLOG(1) << "CollagePartition result:" << std::endl << PrettyPrint(out_mod); |
343 | return out_mod; |
344 | }; |
345 | return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "CollagePartition" , {}); |
346 | } |
347 | |
// Register CollagePartition in the global function registry under the name
// "relay._transform.CollagePartition".
// NOTE(review): presumably this is the entry point the Python frontend looks up -- confirm
// before renaming.
TVM_REGISTER_GLOBAL("relay._transform.CollagePartition" ).set_body_typed(CollagePartition);
349 | |
350 | } // namespace collage |
351 | } // namespace relay |
352 | } // namespace tvm |
353 | |