1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/partition_rule.cc
22 * \brief Compositional partitioning rules.
23 */
24
25#include "./partition_rule.h"
26
27#include <tvm/relay/transform.h>
28
29#include "./partition_rule.h"
30#include "./partition_spec.h"
31#include "./utils.h"
32
33namespace tvm {
34namespace relay {
35namespace collage {
36
37TVM_REGISTER_NODE_TYPE(PartitionRuleNode);
38
39void PartitionRuleNode::VisitAttrs(AttrVisitor* v) {
40 // TODO(mbs)
41}
42
43std::vector<CandidatePartition> PartitionRuleNode::AllCandidates(
44 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
45 ICHECK(false) << "PartitionRuleNode::AllCandidates should be overridden in sub-class";
46 return {};
47}
48
49std::string PartitionRuleNode::ToString() const { return ToDoc().str(); }
50
51Doc PartitionRuleNode::ToDoc() const {
52 Doc doc;
53 doc << GetTypeKey() << "(" << Doc::NewLine(2);
54 std::vector<Doc> body_items;
55 AppendBodyItems(&body_items);
56 doc << Doc::Indent(2, Doc::Concat(body_items, Doc::NewLine())) << Doc::NewLine();
57 doc << ")";
58 return doc;
59}
60
61void PartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
62 body_items->emplace_back();
63 body_items->back() << "rule_name=" << Doc::StrLiteral(rule_name_);
64}
65
66PartitionRule::PartitionRule(String rule_name) {
67 auto node = runtime::make_object<PartitionRuleNode>();
68 node->rule_name_ = std::move(rule_name);
69 data_ = std::move(node);
70}
71
72bool DefaultPatternPredicate(const Expr& matched_sub_expr) { return true; }
73
74TVM_REGISTER_NODE_TYPE(DFPatternPartitionRuleNode);
75
76void DFPatternPartitionRuleNode::VisitAttrs(AttrVisitor* v) {
77 // TODO(mbs)
78}
79
80std::vector<CandidatePartition> DFPatternPartitionRuleNode::AllCandidates(
81 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
82 VLOG(1) << "running DFPatternPartitionRule(" << rule_name_ << ")";
83 std::vector<CandidatePartition> result;
84 DFPatternMatcher matcher(&dataflow_graph.indexed_graph());
85 for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) {
86 Expr sub_expr = dataflow_graph.index_to_node(index)->ref();
87 if (!matcher.Match(pattern_, sub_expr)) {
88 continue;
89 }
90 if (!predicate_(sub_expr)) {
91 VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") has failing predicate";
92 continue;
93 }
94 IndexSet inside = MatcherToIndexSet(matcher);
95 auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
96 SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label));
97 String rule_name = rule_name_.empty() ? sub_graph->label_ : rule_name_;
98 CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec);
99 VLOG(2) << "DFPatternPartitionRule(" << rule_name_ << ") yields " << candidate->ToString();
100 result.emplace_back(std::move(candidate));
101 }
102 VLOG(1) << "DFPatternPartitionRule(" << rule_name_ << ") produced " << result.size()
103 << " candidates";
104 return result;
105}
106
107void DFPatternPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
108 PartitionRuleNode::AppendBodyItems(body_items);
109 body_items->emplace_back();
110 body_items->back() << "pattern=" << PrettyPrint(pattern_);
111}
112
113DFPatternPartitionRule::DFPatternPartitionRule(String rule_name, DFPattern pattern,
114 TPatternPredicate predicate) {
115 auto node = runtime::make_object<DFPatternPartitionRuleNode>();
116 node->rule_name_ = std::move(rule_name);
117 node->pattern_ = std::move(pattern);
118 node->predicate_ = std::move(predicate);
119 data_ = std::move(node);
120}
121
122TVM_REGISTER_NODE_TYPE(CompositePartitionRuleNode);
123
124void CompositePartitionRuleNode::VisitAttrs(AttrVisitor* v) {
125 // TODO(mbs)
126}
127
128std::vector<CandidatePartition> CompositePartitionRuleNode::AllCandidates(
129 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
130 std::vector<CandidatePartition> candidates = sub_rule_->AllCandidates(dataflow_graph, spec);
131 VLOG(1) << "running CompositePartitionRule(" << rule_name_ << ") over " << candidates.size()
132 << " sub-candidates";
133 std::vector<CandidatePartition> result;
134 FunctionAttrsMap attrs;
135 attrs.Set(attr::kComposite, rule_name_);
136 for (auto& candidate : candidates) {
137 String rule_name = NestLabels(rule_name_, candidate->rule_name_);
138 SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs);
139 CandidatePartition new_candidate = WithSubGraph(
140 WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph));
141 VLOG(2) << "CompositePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString();
142 result.emplace_back(std::move(new_candidate));
143 }
144 VLOG(1) << "CompositePartitionRule(" << rule_name_ << ") produced " << result.size()
145 << " candidates";
146 return result;
147}
148
149void CompositePartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
150 PartitionRuleNode::AppendBodyItems(body_items);
151 body_items->emplace_back();
152 body_items->back() << "sub_rule=" << sub_rule_->ToDoc();
153}
154
155CompositePartitionRule::CompositePartitionRule(String rule_name, PartitionRule sub_rule) {
156 auto node = runtime::make_object<CompositePartitionRuleNode>();
157 node->rule_name_ = std::move(rule_name);
158 node->sub_rule_ = std::move(sub_rule);
159 data_ = std::move(node);
160}
161
162TVM_REGISTER_NODE_TYPE(PrimitivePartitionRuleNode);
163
164void PrimitivePartitionRuleNode::VisitAttrs(AttrVisitor* v) {
165 // TODO(mbs)
166}
167
168std::vector<CandidatePartition> PrimitivePartitionRuleNode::AllCandidates(
169 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
170 std::vector<CandidatePartition> candidates = sub_rule_->AllCandidates(dataflow_graph, spec);
171 VLOG(1) << "running PrimitivePartitionRule(" << rule_name_ << ") over " << candidates.size()
172 << " sub-candidates";
173 std::vector<CandidatePartition> result;
174 FunctionAttrsMap attrs;
175 attrs.Set(attr::kPrimitive, Integer(1));
176 if (spec->target_.IsExternalCodegen()) {
177 // The spec name will be the target kind name which is 1:1 with the "Compiler" attribute name.
178 attrs.Set(attr::kCompiler, spec->spec_name_);
179 }
180 for (auto& candidate : candidates) {
181 String rule_name = NestLabels(rule_name_, candidate->rule_name_);
182 SubGraph sub_graph = candidate->sub_graph_.WithAttrs(dataflow_graph, attrs);
183 CandidatePartition new_candidate = WithSubGraph(
184 WithRuleName(std::move(candidate), std::move(rule_name)), std::move(sub_graph));
185 VLOG(2) << "PrimitivePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString();
186 result.emplace_back(std::move(new_candidate));
187 }
188 VLOG(1) << "PrimitivePartitionRule(" << rule_name_ << ") produced " << result.size()
189 << " candidates";
190 return result;
191}
192
193void PrimitivePartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
194 PartitionRuleNode::AppendBodyItems(body_items);
195 body_items->emplace_back();
196 body_items->back() << "sub_rule=" << sub_rule_->ToDoc();
197}
198
199PrimitivePartitionRule::PrimitivePartitionRule(String rule_name, PartitionRule sub_rule) {
200 auto node = runtime::make_object<PrimitivePartitionRuleNode>();
201 node->rule_name_ = std::move(rule_name);
202 node->sub_rule_ = std::move(sub_rule);
203 data_ = std::move(node);
204}
205
206TVM_REGISTER_NODE_TYPE(UnionPartitionRuleNode);
207
208void UnionPartitionRuleNode::VisitAttrs(AttrVisitor* v) {
209 // TODO(mbs)
210}
211
212std::vector<CandidatePartition> UnionPartitionRuleNode::AllCandidates(
213 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
214 std::vector<CandidatePartition> result;
215 for (const auto& sub_rule : sub_rules_) {
216 std::vector<CandidatePartition> candidates = sub_rule->AllCandidates(dataflow_graph, spec);
217 for (auto& candidate : candidates) {
218 String rule_name = NestLabels(rule_name_, candidate->rule_name_);
219 CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name));
220 VLOG(2) << "UnionPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString();
221 result.emplace_back(std::move(new_candidate));
222 }
223 }
224 VLOG(1) << "UnionPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates";
225 return result;
226}
227
228void UnionPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
229 PartitionRuleNode::AppendBodyItems(body_items);
230 for (const auto& sub_rule : sub_rules_) {
231 body_items->emplace_back();
232 body_items->back() << "sub_rule=" << sub_rule->ToDoc();
233 }
234}
235
236UnionPartitionRule::UnionPartitionRule(String rule_name, Array<PartitionRule> sub_rules) {
237 auto node = runtime::make_object<UnionPartitionRuleNode>();
238 node->rule_name_ = std::move(rule_name);
239 node->sub_rules_ = std::move(sub_rules);
240 data_ = std::move(node);
241}
242
243TVM_REGISTER_NODE_TYPE(OpCallByKindPartitionRuleNode);
244
245void OpCallByKindPartitionRuleNode::VisitAttrs(AttrVisitor* v) {
246 // TODO(mbs)
247}
248
249std::vector<CandidatePartition> OpCallByKindPartitionRuleNode::AllCandidates(
250 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
251 VLOG(1) << "running OpCallByKindPartitionRule(" << rule_name_ << ")";
252 std::vector<CandidatePartition> result;
253 for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) {
254 auto node = dataflow_graph.index_to_node(index);
255 Expr sub_expr = node->ref();
256 if (sub_expr->IsInstance<CallNode>()) {
257 auto [kind, label] = SubExprKindAndLabel(sub_expr);
258 if (kind <= kOutEWiseFusable) {
259 IndexSet inside(dataflow_graph.size(), {index});
260 SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label));
261 String rule_name = NestLabels(rule_name_, sub_graph->label_);
262 CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec);
263 VLOG(2) << "OpCallByKindPartitionRule(" << rule_name_ << ") yields "
264 << candidate->ToString();
265 result.emplace_back(std::move(candidate));
266 }
267 }
268 }
269 VLOG(1) << "OpCallByKindPartitionRule(" << rule_name_ << ") produced " << result.size()
270 << " candidates";
271 return result;
272}
273
274void OpCallByKindPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
275 PartitionRuleNode::AppendBodyItems(body_items);
276}
277
278OpCallByKindPartitionRule::OpCallByKindPartitionRule(String rule_name) {
279 auto node = runtime::make_object<OpCallByKindPartitionRuleNode>();
280 node->rule_name_ = std::move(rule_name);
281 data_ = std::move(node);
282}
283
284TVM_REGISTER_NODE_TYPE(CombinePartitionRuleNode);
285
286void CombinePartitionRuleNode::VisitAttrs(AttrVisitor* v) {
287 // TODO(mbs)
288}
289
290std::vector<CandidatePartition> CombinePartitionRuleNode::AllCandidates(
291 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
292 // We'll accumulate all the candidates here, starting with those from the sub-rule.
293 // Once a candidate is added to this vector it is immutable.
294 std::vector<CandidatePartition> candidates = sub_rule_->AllCandidates(dataflow_graph, spec);
295 VLOG(1) << "running CombinePartitionRule(" << rule_name_ << ") over " << candidates.size()
296 << " sub-candidates";
297 CandidateSet result_set(std::move(candidates));
298
299 size_t num_rounds = 0;
300 AppendAllResultsContext ctxt(&dataflow_graph, max_depth_, &result_set);
301 while (result_set.PrepareForNextRound()) {
302 VLOG_CONTEXT << "round " << ++num_rounds;
303 VLOG(1) << "checking " << result_set.size() << " candidates (" << result_set.first_new_index()
304 << " existing)";
305 for (const auto& combiner_rule : combiner_rules_) {
306 combiner_rule->AppendAllResults(&ctxt);
307 }
308 }
309
310 std::vector<CandidatePartition> result;
311 for (auto& candidate : result_set.MovedCurrentCandidates()) {
312 String rule_name = NestLabels(rule_name_, candidate->rule_name_);
313 CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name));
314 VLOG(2) << "CombinePartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString();
315 result.emplace_back(std::move(new_candidate));
316 }
317 VLOG(1) << "CombinePartitionRule(" << rule_name_ << ") produced " << result.size()
318 << " candidates";
319 return result;
320}
321
322void CombinePartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
323 PartitionRuleNode::AppendBodyItems(body_items);
324 body_items->emplace_back();
325 body_items->back() << "sub_rule=" << sub_rule_->ToDoc();
326 for (const auto& combiner_rule : combiner_rules_) {
327 body_items->emplace_back();
328 body_items->back() << "combiner_rule=" << combiner_rule->ToString();
329 }
330 body_items->emplace_back();
331 body_items->back() << "max_depth=" << max_depth_;
332}
333
334CombinePartitionRule::CombinePartitionRule(String rule_name, PartitionRule sub_rule,
335 Array<CombinerRule> combiner_rules, size_t max_depth_) {
336 auto node = runtime::make_object<CombinePartitionRuleNode>();
337 node->rule_name_ = std::move(rule_name);
338 node->sub_rule_ = std::move(sub_rule);
339 node->combiner_rules_ = std::move(combiner_rules);
340 node->max_depth_ = max_depth_;
341 data_ = std::move(node);
342}
343
344TVM_REGISTER_NODE_TYPE(OnlyValidPartitionRuleNode);
345
346void OnlyValidPartitionRuleNode::VisitAttrs(AttrVisitor* v) {
347 // TODO(mbs)
348}
349
350std::vector<CandidatePartition> OnlyValidPartitionRuleNode::AllCandidates(
351 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
352 std::vector<CandidatePartition> candidates = sub_rule_->AllCandidates(dataflow_graph, spec);
353 VLOG(1) << "running OnlyValidPartitionRule(" << rule_name_ << ") over " << candidates.size()
354 << " sub-candidates";
355 std::vector<CandidatePartition> result;
356 for (auto& candidate : candidates) {
357 if (!candidate->sub_graph_->IsValid(dataflow_graph, config_)) {
358 VLOG(2) << "Ignoring invalid candidate " << candidate->ToString();
359 continue;
360 }
361 String rule_name = NestLabels(rule_name_, candidate->rule_name_);
362 CandidatePartition new_candidate = WithRuleName(std::move(candidate), std::move(rule_name));
363 VLOG(2) << "OnlyValidPartitionRule(" << rule_name_ << ") yields " << new_candidate->ToString();
364 result.emplace_back(std::move(new_candidate));
365 }
366 VLOG(1) << "OnlyValidPartitionRule(" << rule_name_ << ") produced " << result.size()
367 << " candidates";
368 return result;
369}
370
371void OnlyValidPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {
372 PartitionRuleNode::AppendBodyItems(body_items);
373 body_items->emplace_back();
374 body_items->back() << "sub_rule=" << sub_rule_->ToDoc();
375 body_items->emplace_back();
376 body_items->back() << "config=" << config_.ToString();
377}
378
379OnlyValidPartitionRule::OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule,
380 const SubGraphConfig& config) {
381 auto node = runtime::make_object<OnlyValidPartitionRuleNode>();
382 node->rule_name_ = std::move(rule_name);
383 node->sub_rule_ = std::move(sub_rule);
384 node->config_ = config;
385 data_ = std::move(node);
386}
387
388TVM_REGISTER_NODE_TYPE(HostPartitionRuleNode);
389
390void HostPartitionRuleNode::VisitAttrs(AttrVisitor* v) {
391 // TODO(mbs)
392}
393
394std::vector<CandidatePartition> HostPartitionRuleNode::AllCandidates(
395 const DataflowGraph& dataflow_graph, const PartitionSpec& spec) const {
396 VLOG(1) << "running HostPartitionRule(" << rule_name_ << ")";
397 std::vector<CandidatePartition> result;
398 for (PostDfsIndex index = 0; index < dataflow_graph.size(); ++index) {
399 if (MustBeLowered(dataflow_graph.index_to_node(index)->ref())) {
400 continue;
401 }
402 IndexSet inside(dataflow_graph.size(), {index});
403 auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
404 SubGraph sub_graph(dataflow_graph, std::move(inside), kind, label);
405 String rule_name = NestLabels(rule_name_, sub_graph->label_);
406 // We'll a zero cost for the candidate since we'll never want to actually estimate the cost
407 // of this 'partition'.
408 CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec, Cost::Zero());
409 VLOG(2) << "HostPartitionRule(" << rule_name_ << ") yields " << candidate->ToString();
410 result.push_back(candidate);
411 }
412 VLOG(1) << "HostPartitionRule(" << rule_name_ << ") produced " << result.size() << " candidates";
413 return result;
414}
415
416void HostPartitionRuleNode::AppendBodyItems(std::vector<Doc>* body_items) const {}
417
418HostPartitionRule::HostPartitionRule(String rule_name) {
419 auto node = runtime::make_object<HostPartitionRuleNode>();
420 node->rule_name_ = std::move(rule_name);
421 data_ = std::move(node);
422}
423
424} // namespace collage
425} // namespace relay
426} // namespace tvm
427