1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/combiner_rule.cc
22 * \brief Helpers for the \p CombinePartitionRule
23 */
24
25#include "./combiner_rule.h"
26
27#include "./partition_spec.h"
28
29namespace tvm {
30namespace relay {
31namespace collage {
32
33TVM_REGISTER_NODE_TYPE(SimpleCombinerRuleNode);
34
35void SimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
36 // TODO(mbs)
37}
38
39bool SimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph,
40 const CandidatePartition& upstream,
41 const CandidatePartition& downstream) const {
42 return false;
43}
44
45std::string SimpleCombinerRuleNode::ToString() const {
46 return "SimpleCombinerRule(" + rule_name_ + ")";
47}
48
49SimpleCombinerRule::SimpleCombinerRule(String rule_name) {
50 auto node = runtime::make_object<SimpleCombinerRuleNode>();
51 node->rule_name_ = std::move(rule_name);
52 data_ = std::move(node);
53}
54
55TVM_REGISTER_NODE_TYPE(ByKindSimpleCombinerRuleNode);
56
57void ByKindSimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
58 // TODO(mbs)
59}
60
61bool ByKindSimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph,
62 const CandidatePartition& upstream,
63 const CandidatePartition& downstream) const {
64 return upstream->sub_graph_->kind_ <= upstream_kind_ &&
65 downstream->sub_graph_->kind_ <= downstream_kind_;
66}
67
68std::string ByKindSimpleCombinerRuleNode::ToString() const {
69 std::ostringstream os;
70 os << "ByKindSimpleCombinerRule(" << rule_name_ << ")";
71 return os.str();
72}
73
74ByKindSimpleCombinerRule::ByKindSimpleCombinerRule(OpPatternKind upstream_kind,
75 OpPatternKind downstream_kind) {
76 auto node = runtime::make_object<ByKindSimpleCombinerRuleNode>();
77 String rule_name = KindToString(upstream_kind) + "->" + KindToString(downstream_kind);
78 node->rule_name_ = std::move(rule_name);
79 node->upstream_kind_ = upstream_kind;
80 node->downstream_kind_ = downstream_kind;
81 data_ = std::move(node);
82}
83
84TVM_REGISTER_NODE_TYPE(CombinerRuleNode);
85
86void CombinerRuleNode::VisitAttrs(AttrVisitor* v) {
87 // TODO(mbs)
88}
89
90void CombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {}
91
92std::string CombinerRuleNode::ToString() const { return "CombinerRuleNode(" + rule_name_ + ")"; }
93
94CombinerRule::CombinerRule(String rule_name) {
95 auto node = runtime::make_object<CombinerRuleNode>();
96 node->rule_name_ = std::move(rule_name);
97 data_ = std::move(node);
98}
99
100TVM_REGISTER_NODE_TYPE(AllSimpleCombinerRuleNode);
101
102void AllSimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
103 // TODO(mbs)
104}
105
106void AllSimpleCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {
107 VLOG(1) << "running AllSimpleCombinerRule(" << rule_name_ << ")";
108 // Build map from post-dfs indices to the indices of candidates with corresponding entry node.
109 // NOTE: the index set is over candidate indices not post-dfs indices!
110 std::vector<IndexSet> entry_map(ctxt->dataflow_graph->size(),
111 IndexSet(ctxt->candidate_set->size()));
112 for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) {
113 CandidatePartition candidate = ctxt->candidate_set->at(i);
114 for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) {
115 entry_map[entry_index].Add(i);
116 }
117 }
118
119 for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) {
120 CandidatePartition upstream = ctxt->candidate_set->at(i);
121 // Narrow our search to just those candidates which could touch.
122 IndexSet possible_downstream(ctxt->candidate_set->size());
123 for (PostDfsIndex output_index : upstream->sub_graph_->output_) {
124 possible_downstream = possible_downstream | entry_map[output_index];
125 }
126 size_t start_j =
127 i < ctxt->candidate_set->first_new_index() ? ctxt->candidate_set->first_new_index() : 0;
128 for (size_t j : possible_downstream) {
129 if (i == j) {
130 continue;
131 }
132 if (i < start_j) {
133 // We already explored the cross-product of candidates [0, first_new_index), so don't
134 // do it again.
135 continue;
136 }
137 // Note that the rules are not commutative so we can't just ignore if j < i.
138 CandidatePartition downstream = ctxt->candidate_set->at(j);
139 if (ctxt->max_depth > 0 &&
140 upstream->sub_graph_->depth_ + downstream->sub_graph_->depth_ > ctxt->max_depth) {
141 continue;
142 }
143 if (!upstream.AreTouching(*ctxt->dataflow_graph, downstream)) {
144 continue;
145 }
146 for (const auto& simple_rule : simple_rules_) {
147 if (simple_rule->Fires(*ctxt->dataflow_graph, upstream, downstream)) {
148 CandidatePartition new_candidate =
149 upstream.DisjointUnion(*ctxt->dataflow_graph, downstream);
150 VLOG(2) << "Fired " << simple_rule->rule_name_ << " on upstream candidate "
151 << upstream->ToString() << " and downstream candidate " << downstream->ToString()
152 << " to yield " << new_candidate->ToString();
153 ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate);
154 }
155 }
156 }
157 }
158}
159
160std::string AllSimpleCombinerRuleNode::ToString() const {
161 std::ostringstream os;
162 os << "AllSimpleCombinerRule(" << rule_name_;
163 for (const auto& simple : simple_rules_) {
164 os << ", " << simple->ToString();
165 }
166 os << ")";
167 return os.str();
168}
169
170AllSimpleCombinerRule::AllSimpleCombinerRule(String rule_name,
171 Array<SimpleCombinerRule> simple_rules) {
172 auto node = runtime::make_object<AllSimpleCombinerRuleNode>();
173 node->rule_name_ = std::move(rule_name);
174 node->simple_rules_ = std::move(simple_rules);
175 data_ = std::move(node);
176}
177
178TVM_REGISTER_NODE_TYPE(TupleArgCombinerRuleNode);
179
180void TupleArgCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
181 // TODO(mbs)
182}
183
184void TupleArgCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {
185 VLOG(1) << "running TupleArgCombinerRule(" << rule_name_ << ")";
186 // Build map from post-dfs index to the indices of injective candidates with corresponding entry
187 // node. NOTE: the index set is over candidate indices not post-dfs indices!
188 std::vector<IndexSet> exit_map(ctxt->dataflow_graph->size(),
189 IndexSet(ctxt->candidate_set->size()));
190 for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) {
191 CandidatePartition candidate = ctxt->candidate_set->at(i);
192 if (candidate->sub_graph_->kind_ > kInjective) {
193 continue;
194 }
195 for (PostDfsIndex exit_index : candidate->sub_graph_->exit_) {
196 exit_map[exit_index].Add(i);
197 }
198 }
199
200 // The two-step I -> tuple -> I rule.
201 // Look all possible tuple consumers...
202 for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) {
203 CandidatePartition tuple_consumer_candidate = ctxt->candidate_set->at(i);
204 if (tuple_consumer_candidate->sub_graph_->kind_ > kInjective) {
205 continue;
206 }
207 // For all possible tuples feeding into candidate...
208 for (PostDfsIndex input_index : tuple_consumer_candidate->sub_graph_->input_) {
209 auto node = ctxt->dataflow_graph->index_to_node(input_index);
210 Expr sub_expr = node->ref();
211 const auto* tuple_node = sub_expr.as<TupleNode>();
212 if (tuple_node == nullptr) {
213 continue;
214 }
215 // The tuple_consumer_candidate candidate consumes (at least one) tuple, eg as an argument
216 // to an operator.
217 // eg: concatenate((field1, ..., fieldn))
218 auto tuple_dataflow_node = ctxt->dataflow_graph->item_to_node(tuple_node);
219
220 // Collect all the possible unions. There may be more than one if different candidates
221 // could supply the same tuple field.
222 std::vector<std::vector<CandidatePartition>> all_possible_unions;
223
224 // Obviously we must include the consumer.
225 all_possible_unions.emplace_back();
226 all_possible_unions.back().emplace_back(tuple_consumer_candidate);
227
228 // We must include the tuple itself.
229 SubGraph tuple_sub_graph(*ctxt->dataflow_graph,
230 IndexSet(ctxt->dataflow_graph->size(), {node->index_}), kInjective,
231 "tuple");
232 CandidatePartition tuple_candidate("", std::move(tuple_sub_graph),
233 tuple_consumer_candidate->partition_spec());
234 all_possible_unions.back().emplace_back(std::move(tuple_candidate));
235
236 // For all tuple fields...
237 bool all_tuple_fields_have_producer = true;
238 for (auto* tuple_field_dataflow_node : tuple_dataflow_node->inputs_) {
239 // Collect all the candidates which could produce this tuple field.
240 std::vector<CandidatePartition> to_appends;
241 size_t start_j =
242 i < ctxt->candidate_set->first_new_index() ? ctxt->candidate_set->first_new_index() : 0;
243 for (size_t j : exit_map[tuple_field_dataflow_node->index_]) {
244 if (i == j) {
245 continue;
246 }
247 if (i < start_j) {
248 // We already explored the cross-product of candidates [0, first_new_index), so don't
249 // do it again.
250 continue;
251 }
252 CandidatePartition tuple_field_producer = ctxt->candidate_set->at(j);
253 // The tuple_field_producer candidate can provide this tuple field.
254 // eg concatenate((..., producer, ...))
255 to_appends.emplace_back(tuple_field_producer);
256 }
257 if (to_appends.empty()) {
258 // At least one of the tuple's fields does not have a producer candidate we can
259 // union in, so we need to give up.
260 all_tuple_fields_have_producer = false;
261 break;
262 } else {
263 // If to_appends = [A, B] and we already have possible unions [C, D] and [E, F] then
264 // the new possible unions are [C, D, A], [C, D, B], [E, F, A] and [E, F, B].
265 std::vector<std::vector<CandidatePartition>> new_all_possible_unions;
266 for (const auto& to_append : to_appends) {
267 for (const auto& possible_union : all_possible_unions) {
268 new_all_possible_unions.emplace_back(possible_union);
269 new_all_possible_unions.back().emplace_back(to_append);
270 }
271 }
272 all_possible_unions = std::move(new_all_possible_unions);
273 }
274 }
275
276 if (!all_tuple_fields_have_producer) {
277 continue;
278 }
279
280 // Actually build the candidates which union according to all_possible_unions.
281 for (const auto& possible_union : all_possible_unions) {
282 if (possible_union.size() > 2) {
283 CandidatePartition new_candidate =
284 CandidatePartition::DisjointUnion(*ctxt->dataflow_graph, possible_union);
285#if TVM_LOG_DEBUG
286 std::ostringstream os;
287 bool first = true;
288 for (const auto& candidate : possible_union) {
289 if (first) {
290 first = false;
291 } else {
292 os << ", ";
293 }
294 os << candidate->ToString();
295 }
296 VLOG(2) << "Fired rule " << rule_name_ << " on {" << os.str() << "} to yield "
297 << new_candidate->ToString();
298#endif
299 ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate);
300 }
301 }
302 }
303 }
304}
305
306std::string TupleArgCombinerRuleNode::ToString() const {
307 return "TupleArgCombinerRule(" + rule_name_ + ")";
308}
309
310TupleArgCombinerRule::TupleArgCombinerRule(String rule_name) {
311 auto node = runtime::make_object<TupleArgCombinerRuleNode>();
312 node->rule_name_ = std::move(rule_name);
313 data_ = std::move(node);
314}
315
316TVM_REGISTER_NODE_TYPE(TupleProjCombinerRuleNode);
317
318void TupleProjCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
319 // TODO(mbs)
320}
321
322void TupleProjCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {
323 VLOG(1) << "running TupleProjCombinerRule(" << rule_name_ << ")";
324 // We already explored [0, first_new_index), so don't do it again.
325 for (size_t i = ctxt->candidate_set->first_new_index(); i < ctxt->candidate_set->size(); ++i) {
326 CandidatePartition base = ctxt->candidate_set->at(i);
327 for (PostDfsIndex index : base->sub_graph_->output_) {
328 auto node = ctxt->dataflow_graph->index_to_node(index);
329 if (node->ref().as<TupleGetItemNode>()) {
330 IndexSet index_set(ctxt->dataflow_graph->size(), {node->index_});
331 SubGraph sub_graph(*ctxt->dataflow_graph, std::move(index_set), kInjective, "proj");
332 CandidatePartition proj_candidate("", std::move(sub_graph), base->spec_);
333 CandidatePartition new_candidate =
334 base.DisjointUnion(*ctxt->dataflow_graph, proj_candidate);
335 VLOG(2) << "Fired rule " << rule_name_ << " on " << proj_candidate->ToString() << " and "
336 << base->ToString() << " to yield " << new_candidate->ToString();
337 ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate);
338 }
339 }
340 }
341}
342
343std::string TupleProjCombinerRuleNode::ToString() const {
344 return "TupleProjCombinerRule(" + rule_name_ + ")";
345}
346
347TupleProjCombinerRule::TupleProjCombinerRule(String rule_name) {
348 auto node = runtime::make_object<TupleProjCombinerRuleNode>();
349 node->rule_name_ = std::move(rule_name);
350 data_ = std::move(node);
351}
352
353TVM_REGISTER_NODE_TYPE(ConstantCombinerRuleNode);
354
355void ConstantCombinerRuleNode::VisitAttrs(AttrVisitor* v) {
356 // TODO(mbs)
357}
358
359void ConstantCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {
360 VLOG(1) << "running ConstantCombinerRule(" << rule_name_ << ")";
361 // We already explored [0, first_new_index), so don't do it again.
362 for (size_t i = ctxt->candidate_set->first_new_index(); i < ctxt->candidate_set->size(); ++i) {
363 CandidatePartition base = ctxt->candidate_set->at(i);
364 IndexSet new_constants(ctxt->dataflow_graph->size());
365 for (PostDfsIndex index : base->sub_graph_->input_) {
366 auto node = ctxt->dataflow_graph->index_to_node(index);
367 if (node->ref().as<ConstantNode>()) {
368 new_constants.Add(index);
369 }
370 }
371 if (!new_constants.IsZero()) {
372 SubGraph sub_graph(*ctxt->dataflow_graph, new_constants, kElemWise, "const");
373 CandidatePartition new_const_candidate("", std::move(sub_graph), base->spec_);
374 CandidatePartition new_candidate =
375 base.DisjointUnion(*ctxt->dataflow_graph, new_const_candidate);
376 VLOG(2) << "Fired rule " << rule_name_ << " on " << new_const_candidate->ToString() << " and "
377 << base->ToString() << " to yield " << new_candidate->ToString();
378 ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate);
379 }
380 }
381}
382
383std::string ConstantCombinerRuleNode::ToString() const {
384 return "ConstantCombinerRule(" + rule_name_ + ")";
385}
386
387ConstantCombinerRule::ConstantCombinerRule(String rule_name) {
388 auto node = runtime::make_object<ConstantCombinerRuleNode>();
389 node->rule_name_ = std::move(rule_name);
390 data_ = std::move(node);
391}
392
393} // namespace collage
394} // namespace relay
395} // namespace tvm
396