1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/partition_rule.cc |
22 | * \brief Compositional partitioning rules. |
23 | */ |
24 | |
25 | #include "./partition_rule.h" |
26 | |
27 | #include <tvm/relay/transform.h> |
28 | |
29 | #include "./partition_rule.h" |
30 | #include "./partition_spec.h" |
31 | #include "./utils.h" |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | namespace collage { |
36 | |
37 | TVM_REGISTER_NODE_TYPE(PartitionRuleNode); |
38 | |
39 | void PartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
40 | // TODO(mbs) |
41 | } |
42 | |
43 | std::vector<CandidatePartition> PartitionRuleNode::AllCandidates( |
44 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
45 | ICHECK(false) << "PartitionRuleNode::AllCandidates should be overridden in sub-class" ; |
46 | return {}; |
47 | } |
48 | |
49 | std::string PartitionRuleNode::ToString() const { return ToDoc().str(); } |
50 | |
51 | Doc PartitionRuleNode::ToDoc() const { |
52 | Doc doc; |
53 | doc << GetTypeKey() << "(" << Doc::NewLine(2); |
54 | std::vector<Doc> body_items; |
55 | AppendBodyItems(&body_items); |
56 | doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine(); |
57 | doc << ")" ; |
58 | return doc; |
59 | } |
60 | |
61 | void PartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const { |
62 | body_items->emplace_back(); |
63 | body_items->back() << "rule_name=" << Doc::StrLiteral(rule_name_); |
64 | } |
65 | |
66 | PartitionRule::PartitionRule(String rule_name) { |
67 | auto node = runtime::make_object<PartitionRuleNode>(); |
68 | node->rule_name_ = std::move(rule_name); |
69 | data_ = std::move(node); |
70 | } |
71 | |
72 | bool DefaultPatternPredicate(const Expr& matched_sub_expr) { return true; } |
73 | |
74 | TVM_REGISTER_NODE_TYPE(DFPatternPartitionRuleNode); |
75 | |
76 | void DFPatternPartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
77 | // TODO(mbs) |
78 | } |
79 | |
80 | std::vector<CandidatePartition> DFPatternPartitionRuleNode::AllCandidates( |
81 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
82 | VLOG(1) << "running DFPatternPartitionRule(" << rule_name_ << ")" ; |
83 | std::vector<CandidatePartition> result; |
84 | DFPatternMatcher matcher(&dataflow_graph.indexed_graph()); |
85 | for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { |
86 | Expr sub_expr = dataflow_graph.index_to_node(index)->ref(); |
87 | if (!matcher.Match(pattern_, sub_expr)) { |
88 | continue; |
89 | } |
90 | if (!predicate_(sub_expr)) { |
91 | VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") has failing predicate" ; |
92 | continue; |
93 | } |
94 | IndexSet inside = MatcherToIndexSet(matcher); |
95 | auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside); |
96 | SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); |
97 | String rule_name = rule_name_.empty() ? sub_graph->label_ : rule_name_; |
98 | CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); |
99 | VLOG(2) << "DFPatternPartitionRule(" << rule_name_ << ") yields " << candidate->ToString(); |
100 | result.emplace_back(std::move(candidate)); |
101 | } |
102 | VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") produced " << result.size() |
103 | << " candidates" ; |
104 | return result; |
105 | } |
106 | |
107 | void DFPatternPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const { |
108 | PartitionRuleNode::AppendBodyItems(body_items); |
109 | body_items->emplace_back(); |
110 | body_items->back() << "pattern=" << PrettyPrint(pattern_); |
111 | } |
112 | |
113 | DFPatternPartitionRule::DFPatternPartitionRule(String rule_name, DFPattern pattern, |
114 | TPatternPredicate predicate) { |
115 | auto node = runtime::make_object<DFPatternPartitionRuleNode>(); |
116 | node->rule_name_ = std::move(rule_name); |
117 | node->pattern_ = std::move(pattern); |
118 | node->predicate_ = std::move(predicate); |
119 | data_ = std::move(node); |
120 | } |
121 | |
122 | TVM_REGISTER_NODE_TYPE(CompositePartitionRuleNode); |
123 | |
124 | void CompositePartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
125 | // TODO(mbs) |
126 | } |
127 | |
128 | std::vector<CandidatePartition> CompositePartitionRuleNode::AllCandidates( |
129 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
130 | std::vector<CandidatePartition> candidates = sub_rule_->AllCandidates(dataflow_graph, spec); |
131 | VLOG(1) << "running CompositePartitionRule(" << rule_name_ << ") over " << candidates.size() |
132 | << " sub-candidates" ; |
133 | std::vector<CandidatePartition> result; |
134 | FunctionAttrsMap attrs; |
135 | attrs.Set(attr::kComposite, rule_name_); |
136 | for (auto& candidate : candidates) { |
137 | String rule_name = NestLabels(rule_name_, candidate->rule_name_); |
138 | SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs); |
139 | CandidatePartition new_candidate = WithSubGraph( |
140 | WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph)); |
141 | VLOG(2) << "CompositePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); |
142 | result.emplace_back(std::move(new_candidate)); |
143 | } |
144 | VLOG(1) << "CompositePartitionRule(" << rule_name_ << ") produced " << result.size() |
145 | << " candidates" ; |
146 | return result; |
147 | } |
148 | |
149 | void CompositePartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const { |
150 | PartitionRuleNode::AppendBodyItems(body_items); |
151 | body_items->emplace_back(); |
152 | body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); |
153 | } |
154 | |
155 | CompositePartitionRule::CompositePartitionRule(String rule_name, PartitionRule sub_rule) { |
156 | auto node = runtime::make_object<CompositePartitionRuleNode>(); |
157 | node->rule_name_ = std::move(rule_name); |
158 | node->sub_rule_ = std::move(sub_rule); |
159 | data_ = std::move(node); |
160 | } |
161 | |
162 | TVM_REGISTER_NODE_TYPE(PrimitivePartitionRuleNode); |
163 | |
164 | void PrimitivePartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
165 | // TODO(mbs) |
166 | } |
167 | |
168 | std::vector<CandidatePartition> PrimitivePartitionRuleNode::AllCandidates( |
169 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
170 | std::vector<CandidatePartition> candidates = sub_rule_->AllCandidates(dataflow_graph, spec); |
171 | VLOG(1) << "running PrimitivePartitionRule(" << rule_name_ << ") over " << candidates.size() |
172 | << " sub-candidates" ; |
173 | std::vector<CandidatePartition> result; |
174 | FunctionAttrsMap attrs; |
175 | attrs.Set(attr::kPrimitive, Integer(1)); |
176 | if (spec->target_.IsExternalCodegen()) { |
177 | // The spec name will be the target kind name which is 1:1 with the "Compiler" attribute name. |
178 | attrs.Set(attr::kCompiler, spec->spec_name_); |
179 | } |
180 | for (auto& candidate : candidates) { |
181 | String rule_name = NestLabels(rule_name_, candidate->rule_name_); |
182 | SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs); |
183 | CandidatePartition new_candidate = WithSubGraph( |
184 | WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph)); |
185 | VLOG(2) << "PrimitivePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); |
186 | result.emplace_back(std::move(new_candidate)); |
187 | } |
188 | VLOG(1) << "PrimitivePartitionRule(" << rule_name_ << ") produced " << result.size() |
189 | << " candidates" ; |
190 | return result; |
191 | } |
192 | |
193 | void PrimitivePartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const { |
194 | PartitionRuleNode::AppendBodyItems(body_items); |
195 | body_items->emplace_back(); |
196 | body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); |
197 | } |
198 | |
199 | PrimitivePartitionRule::PrimitivePartitionRule(String rule_name, PartitionRule sub_rule) { |
200 | auto node = runtime::make_object<PrimitivePartitionRuleNode>(); |
201 | node->rule_name_ = std::move(rule_name); |
202 | node->sub_rule_ = std::move(sub_rule); |
203 | data_ = std::move(node); |
204 | } |
205 | |
206 | TVM_REGISTER_NODE_TYPE(UnionPartitionRuleNode); |
207 | |
208 | void UnionPartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
209 | // TODO(mbs) |
210 | } |
211 | |
212 | std::vector<CandidatePartition> UnionPartitionRuleNode::AllCandidates( |
213 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
214 | std::vector<CandidatePartition> result; |
215 | for (const auto& sub_rule : sub_rules_) { |
216 | std::vector<CandidatePartition> candidates = sub_rule->AllCandidates(dataflow_graph, spec); |
217 | for (auto& candidate : candidates) { |
218 | String rule_name = NestLabels(rule_name_, candidate->rule_name_); |
219 | CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); |
220 | VLOG(2) << "UnionPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); |
221 | result.emplace_back(std::move(new_candidate)); |
222 | } |
223 | } |
224 | VLOG(1) << "UnionPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates" ; |
225 | return result; |
226 | } |
227 | |
228 | void UnionPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const { |
229 | PartitionRuleNode::AppendBodyItems(body_items); |
230 | for (const auto& sub_rule : sub_rules_) { |
231 | body_items->emplace_back(); |
232 | body_items->back() << "sub_rule=" << sub_rule->ToDoc(); |
233 | } |
234 | } |
235 | |
236 | UnionPartitionRule::UnionPartitionRule(String rule_name, Array<PartitionRule> sub_rules) { |
237 | auto node = runtime::make_object<UnionPartitionRuleNode>(); |
238 | node->rule_name_ = std::move(rule_name); |
239 | node->sub_rules_ = std::move(sub_rules); |
240 | data_ = std::move(node); |
241 | } |
242 | |
243 | TVM_REGISTER_NODE_TYPE(OpCallByKindPartitionRuleNode); |
244 | |
245 | void OpCallByKindPartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
246 | // TODO(mbs) |
247 | } |
248 | |
249 | std::vector<CandidatePartition> OpCallByKindPartitionRuleNode::AllCandidates( |
250 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
251 | VLOG(1) << "running OpCallByKindPartitionRule(" << rule_name_ << ")" ; |
252 | std::vector<CandidatePartition> result; |
253 | for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { |
254 | auto node = dataflow_graph.index_to_node(index); |
255 | Expr sub_expr = node->ref(); |
256 | if (sub_expr->IsInstance<CallNode>()) { |
257 | auto [kind, label] = SubExprKindAndLabel(sub_expr); |
258 | if (kind <= kOutEWiseFusable) { |
259 | IndexSet inside(dataflow_graph.size(), {index}); |
260 | SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); |
261 | String rule_name = NestLabels(rule_name_, sub_graph->label_); |
262 | CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); |
263 | VLOG(2) << "OpCallByKindPartitionRule(" << rule_name_ << ") yields " |
264 | << candidate->ToString(); |
265 | result.emplace_back(std::move(candidate)); |
266 | } |
267 | } |
268 | } |
269 | VLOG(1) << "OpCallByKindPartitionRule(" << rule_name_ << ") produced " << result.size() |
270 | << " candidates" ; |
271 | return result; |
272 | } |
273 | |
274 | void OpCallByKindPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const { |
275 | PartitionRuleNode::AppendBodyItems(body_items); |
276 | } |
277 | |
278 | OpCallByKindPartitionRule::OpCallByKindPartitionRule(String rule_name) { |
279 | auto node = runtime::make_object<OpCallByKindPartitionRuleNode>(); |
280 | node->rule_name_ = std::move(rule_name); |
281 | data_ = std::move(node); |
282 | } |
283 | |
284 | TVM_REGISTER_NODE_TYPE(CombinePartitionRuleNode); |
285 | |
286 | void CombinePartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
287 | // TODO(mbs) |
288 | } |
289 | |
290 | std::vector<CandidatePartition> CombinePartitionRuleNode::AllCandidates( |
291 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
292 | // We'll accumulate all the candidates here, starting with those from the sub-rule. |
293 | // Once a candidate is added to this vector it is immutable. |
294 | std::vector<CandidatePartition> candidates = sub_rule_->AllCandidates(dataflow_graph, spec); |
295 | VLOG(1) << "running CombinePartitionRule(" << rule_name_ << ") over " << candidates.size() |
296 | << " sub-candidates" ; |
297 | CandidateSet result_set(std::move(candidates)); |
298 | |
299 | size_t num_rounds = 0; |
300 | AppendAllResultsContext ctxt(&dataflow_graph, max_depth_, &result_set); |
301 | while (result_set.PrepareForNextRound()) { |
302 | VLOG_CONTEXT << "round " << ++num_rounds; |
303 | VLOG(1) << "checking " << result_set.size() << " candidates (" << result_set.first_new_index() |
304 | << " existing)" ; |
305 | for (const auto& combiner_rule : combiner_rules_) { |
306 | combiner_rule->AppendAllResults(&ctxt); |
307 | } |
308 | } |
309 | |
310 | std::vector<CandidatePartition> result; |
311 | for (auto& candidate : result_set.MovedCurrentCandidates()) { |
312 | String rule_name = NestLabels(rule_name_, candidate->rule_name_); |
313 | CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); |
314 | VLOG(2) << "CombinePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); |
315 | result.emplace_back(std::move(new_candidate)); |
316 | } |
317 | VLOG(1) << "CombinePartitionRule(" << rule_name_ << ") produced " << result.size() |
318 | << " candidates" ; |
319 | return result; |
320 | } |
321 | |
322 | void CombinePartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const { |
323 | PartitionRuleNode::AppendBodyItems(body_items); |
324 | body_items->emplace_back(); |
325 | body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); |
326 | for (const auto& combiner_rule : combiner_rules_) { |
327 | body_items->emplace_back(); |
328 | body_items->back() << "combiner_rule=" << combiner_rule->ToString(); |
329 | } |
330 | body_items->emplace_back(); |
331 | body_items->back() << "max_depth=" << max_depth_; |
332 | } |
333 | |
334 | CombinePartitionRule::CombinePartitionRule(String rule_name, PartitionRule sub_rule, |
335 | Array<CombinerRule> combiner_rules, size_t max_depth_) { |
336 | auto node = runtime::make_object<CombinePartitionRuleNode>(); |
337 | node->rule_name_ = std::move(rule_name); |
338 | node->sub_rule_ = std::move(sub_rule); |
339 | node->combiner_rules_ = std::move(combiner_rules); |
340 | node->max_depth_ = max_depth_; |
341 | data_ = std::move(node); |
342 | } |
343 | |
344 | TVM_REGISTER_NODE_TYPE(OnlyValidPartitionRuleNode); |
345 | |
346 | void OnlyValidPartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
347 | // TODO(mbs) |
348 | } |
349 | |
350 | std::vector<CandidatePartition> OnlyValidPartitionRuleNode::AllCandidates( |
351 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
352 | std::vector<CandidatePartition> candidates = sub_rule_->AllCandidates(dataflow_graph, spec); |
353 | VLOG(1) << "running OnlyValidPartitionRule(" << rule_name_ << ") over " << candidates.size() |
354 | << " sub-candidates" ; |
355 | std::vector<CandidatePartition> result; |
356 | for (auto& candidate : candidates) { |
357 | if (!candidate->sub_graph_->IsValid(dataflow_graph, config_)) { |
358 | VLOG(2) << "Ignoring invalid candidate " << candidate->ToString(); |
359 | continue; |
360 | } |
361 | String rule_name = NestLabels(rule_name_, candidate->rule_name_); |
362 | CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name)); |
363 | VLOG(2) << "OnlyValidPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString(); |
364 | result.emplace_back(std::move(new_candidate)); |
365 | } |
366 | VLOG(1) << "OnlyValidPartitionRule(" << rule_name_ << ") produced " << result.size() |
367 | << " candidates" ; |
368 | return result; |
369 | } |
370 | |
371 | void OnlyValidPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const { |
372 | PartitionRuleNode::AppendBodyItems(body_items); |
373 | body_items->emplace_back(); |
374 | body_items->back() << "sub_rule=" << sub_rule_->ToDoc(); |
375 | body_items->emplace_back(); |
376 | body_items->back() << "config=" << config_.ToString(); |
377 | } |
378 | |
379 | OnlyValidPartitionRule::OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, |
380 | const SubGraphConfig& config) { |
381 | auto node = runtime::make_object<OnlyValidPartitionRuleNode>(); |
382 | node->rule_name_ = std::move(rule_name); |
383 | node->sub_rule_ = std::move(sub_rule); |
384 | node->config_ = config; |
385 | data_ = std::move(node); |
386 | } |
387 | |
388 | TVM_REGISTER_NODE_TYPE(HostPartitionRuleNode); |
389 | |
390 | void HostPartitionRuleNode::VisitAttrs(AttrVisitor* v) { |
391 | // TODO(mbs) |
392 | } |
393 | |
394 | std::vector<CandidatePartition> HostPartitionRuleNode::AllCandidates( |
395 | const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const { |
396 | VLOG(1) << "running HostPartitionRule(" << rule_name_ << ")" ; |
397 | std::vector<CandidatePartition> result; |
398 | for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) { |
399 | if (MustBeLowered(dataflow_graph.index_to_node(index)->ref())) { |
400 | continue; |
401 | } |
402 | IndexSet inside(dataflow_graph.size(), {index}); |
403 | auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside); |
404 | SubGraph sub_graph(dataflow_graph, std::move(inside), kind, label); |
405 | String rule_name = NestLabels(rule_name_, sub_graph->label_); |
406 | // We'll a zero cost for the candidate since we'll never want to actually estimate the cost |
407 | // of this 'partition'. |
408 | CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec, Cost::Zero()); |
409 | VLOG(2) << "HostPartitionRule(" << rule_name_ << ") yields " << candidate->ToString(); |
410 | result.push_back(candidate); |
411 | } |
412 | VLOG(1) << "HostPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates" ; |
413 | return result; |
414 | } |
415 | |
416 | void HostPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {} |
417 | |
418 | HostPartitionRule::HostPartitionRule(String rule_name) { |
419 | auto node = runtime::make_object<HostPartitionRuleNode>(); |
420 | node->rule_name_ = std::move(rule_name); |
421 | data_ = std::move(node); |
422 | } |
423 | |
424 | } // namespace collage |
425 | } // namespace relay |
426 | } // namespace tvm |
427 | |