1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/combiner_rule.cc |
22 | * \brief Helpers for the \p CombinePartitionRule |
23 | */ |
24 | |
25 | #include "./combiner_rule.h" |
26 | |
27 | #include "./partition_spec.h" |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | namespace collage { |
32 | |
33 | TVM_REGISTER_NODE_TYPE(SimpleCombinerRuleNode); |
34 | |
35 | void SimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) { |
36 | // TODO(mbs) |
37 | } |
38 | |
39 | bool SimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph, |
40 | const CandidatePartition& upstream, |
41 | const CandidatePartition& downstream) const { |
42 | return false; |
43 | } |
44 | |
45 | std::string SimpleCombinerRuleNode::ToString() const { |
46 | return "SimpleCombinerRule(" + rule_name_ + ")" ; |
47 | } |
48 | |
49 | SimpleCombinerRule::SimpleCombinerRule(String rule_name) { |
50 | auto node = runtime::make_object<SimpleCombinerRuleNode>(); |
51 | node->rule_name_ = std::move(rule_name); |
52 | data_ = std::move(node); |
53 | } |
54 | |
55 | TVM_REGISTER_NODE_TYPE(ByKindSimpleCombinerRuleNode); |
56 | |
57 | void ByKindSimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) { |
58 | // TODO(mbs) |
59 | } |
60 | |
61 | bool ByKindSimpleCombinerRuleNode::Fires(const DataflowGraph& dataflow_graph, |
62 | const CandidatePartition& upstream, |
63 | const CandidatePartition& downstream) const { |
64 | return upstream->sub_graph_->kind_ <= upstream_kind_ && |
65 | downstream->sub_graph_->kind_ <= downstream_kind_; |
66 | } |
67 | |
68 | std::string ByKindSimpleCombinerRuleNode::ToString() const { |
69 | std::ostringstream os; |
70 | os << "ByKindSimpleCombinerRule(" << rule_name_ << ")" ; |
71 | return os.str(); |
72 | } |
73 | |
74 | ByKindSimpleCombinerRule::ByKindSimpleCombinerRule(OpPatternKind upstream_kind, |
75 | OpPatternKind downstream_kind) { |
76 | auto node = runtime::make_object<ByKindSimpleCombinerRuleNode>(); |
77 | String rule_name = KindToString(upstream_kind) + "->" + KindToString(downstream_kind); |
78 | node->rule_name_ = std::move(rule_name); |
79 | node->upstream_kind_ = upstream_kind; |
80 | node->downstream_kind_ = downstream_kind; |
81 | data_ = std::move(node); |
82 | } |
83 | |
84 | TVM_REGISTER_NODE_TYPE(CombinerRuleNode); |
85 | |
86 | void CombinerRuleNode::VisitAttrs(AttrVisitor* v) { |
87 | // TODO(mbs) |
88 | } |
89 | |
90 | void CombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const {} |
91 | |
92 | std::string CombinerRuleNode::ToString() const { return "CombinerRuleNode(" + rule_name_ + ")" ; } |
93 | |
94 | CombinerRule::CombinerRule(String rule_name) { |
95 | auto node = runtime::make_object<CombinerRuleNode>(); |
96 | node->rule_name_ = std::move(rule_name); |
97 | data_ = std::move(node); |
98 | } |
99 | |
100 | TVM_REGISTER_NODE_TYPE(AllSimpleCombinerRuleNode); |
101 | |
102 | void AllSimpleCombinerRuleNode::VisitAttrs(AttrVisitor* v) { |
103 | // TODO(mbs) |
104 | } |
105 | |
106 | void AllSimpleCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const { |
107 | VLOG(1) << "running AllSimpleCombinerRule(" << rule_name_ << ")" ; |
108 | // Build map from post-dfs indices to the indices of candidates with corresponding entry node. |
109 | // NOTE: the index set is over candidate indices not post-dfs indices! |
110 | std::vector<IndexSet> entry_map(ctxt->dataflow_graph->size(), |
111 | IndexSet(ctxt->candidate_set->size())); |
112 | for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) { |
113 | CandidatePartition candidate = ctxt->candidate_set->at(i); |
114 | for (PostDfsIndex entry_index : candidate->sub_graph_->entry_) { |
115 | entry_map[entry_index].Add(i); |
116 | } |
117 | } |
118 | |
119 | for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) { |
120 | CandidatePartition upstream = ctxt->candidate_set->at(i); |
121 | // Narrow our search to just those candidates which could touch. |
122 | IndexSet possible_downstream(ctxt->candidate_set->size()); |
123 | for (PostDfsIndex output_index : upstream->sub_graph_->output_) { |
124 | possible_downstream = possible_downstream | entry_map[output_index]; |
125 | } |
126 | size_t start_j = |
127 | i < ctxt->candidate_set->first_new_index() ? ctxt->candidate_set->first_new_index() : 0; |
128 | for (size_t j : possible_downstream) { |
129 | if (i == j) { |
130 | continue; |
131 | } |
132 | if (i < start_j) { |
133 | // We already explored the cross-product of candidates [0, first_new_index), so don't |
134 | // do it again. |
135 | continue; |
136 | } |
137 | // Note that the rules are not commutative so we can't just ignore if j < i. |
138 | CandidatePartition downstream = ctxt->candidate_set->at(j); |
139 | if (ctxt->max_depth > 0 && |
140 | upstream->sub_graph_->depth_ + downstream->sub_graph_->depth_ > ctxt->max_depth) { |
141 | continue; |
142 | } |
143 | if (!upstream.AreTouching(*ctxt->dataflow_graph, downstream)) { |
144 | continue; |
145 | } |
146 | for (const auto& simple_rule : simple_rules_) { |
147 | if (simple_rule->Fires(*ctxt->dataflow_graph, upstream, downstream)) { |
148 | CandidatePartition new_candidate = |
149 | upstream.DisjointUnion(*ctxt->dataflow_graph, downstream); |
150 | VLOG(2) << "Fired " << simple_rule->rule_name_ << " on upstream candidate " |
151 | << upstream->ToString() << " and downstream candidate " << downstream->ToString() |
152 | << " to yield " << new_candidate->ToString(); |
153 | ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate); |
154 | } |
155 | } |
156 | } |
157 | } |
158 | } |
159 | |
160 | std::string AllSimpleCombinerRuleNode::ToString() const { |
161 | std::ostringstream os; |
162 | os << "AllSimpleCombinerRule(" << rule_name_; |
163 | for (const auto& simple : simple_rules_) { |
164 | os << ", " << simple->ToString(); |
165 | } |
166 | os << ")" ; |
167 | return os.str(); |
168 | } |
169 | |
170 | AllSimpleCombinerRule::AllSimpleCombinerRule(String rule_name, |
171 | Array<SimpleCombinerRule> simple_rules) { |
172 | auto node = runtime::make_object<AllSimpleCombinerRuleNode>(); |
173 | node->rule_name_ = std::move(rule_name); |
174 | node->simple_rules_ = std::move(simple_rules); |
175 | data_ = std::move(node); |
176 | } |
177 | |
178 | TVM_REGISTER_NODE_TYPE(TupleArgCombinerRuleNode); |
179 | |
180 | void TupleArgCombinerRuleNode::VisitAttrs(AttrVisitor* v) { |
181 | // TODO(mbs) |
182 | } |
183 | |
184 | void TupleArgCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const { |
185 | VLOG(1) << "running TupleArgCombinerRule(" << rule_name_ << ")" ; |
186 | // Build map from post-dfs index to the indices of injective candidates with corresponding entry |
187 | // node. NOTE: the index set is over candidate indices not post-dfs indices! |
188 | std::vector<IndexSet> exit_map(ctxt->dataflow_graph->size(), |
189 | IndexSet(ctxt->candidate_set->size())); |
190 | for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) { |
191 | CandidatePartition candidate = ctxt->candidate_set->at(i); |
192 | if (candidate->sub_graph_->kind_ > kInjective) { |
193 | continue; |
194 | } |
195 | for (PostDfsIndex exit_index : candidate->sub_graph_->exit_) { |
196 | exit_map[exit_index].Add(i); |
197 | } |
198 | } |
199 | |
200 | // The two-step I -> tuple -> I rule. |
201 | // Look all possible tuple consumers... |
202 | for (size_t i = 0; i < ctxt->candidate_set->size(); ++i) { |
203 | CandidatePartition tuple_consumer_candidate = ctxt->candidate_set->at(i); |
204 | if (tuple_consumer_candidate->sub_graph_->kind_ > kInjective) { |
205 | continue; |
206 | } |
207 | // For all possible tuples feeding into candidate... |
208 | for (PostDfsIndex input_index : tuple_consumer_candidate->sub_graph_->input_) { |
209 | auto node = ctxt->dataflow_graph->index_to_node(input_index); |
210 | Expr sub_expr = node->ref(); |
211 | const auto* tuple_node = sub_expr.as<TupleNode>(); |
212 | if (tuple_node == nullptr) { |
213 | continue; |
214 | } |
215 | // The tuple_consumer_candidate candidate consumes (at least one) tuple, eg as an argument |
216 | // to an operator. |
217 | // eg: concatenate((field1, ..., fieldn)) |
218 | auto tuple_dataflow_node = ctxt->dataflow_graph->item_to_node(tuple_node); |
219 | |
220 | // Collect all the possible unions. There may be more than one if different candidates |
221 | // could supply the same tuple field. |
222 | std::vector<std::vector<CandidatePartition>> all_possible_unions; |
223 | |
224 | // Obviously we must include the consumer. |
225 | all_possible_unions.emplace_back(); |
226 | all_possible_unions.back().emplace_back(tuple_consumer_candidate); |
227 | |
228 | // We must include the tuple itself. |
229 | SubGraph tuple_sub_graph(*ctxt->dataflow_graph, |
230 | IndexSet(ctxt->dataflow_graph->size(), {node->index_}), kInjective, |
231 | "tuple" ); |
232 | CandidatePartition tuple_candidate("" , std::move(tuple_sub_graph), |
233 | tuple_consumer_candidate->partition_spec()); |
234 | all_possible_unions.back().emplace_back(std::move(tuple_candidate)); |
235 | |
236 | // For all tuple fields... |
237 | bool all_tuple_fields_have_producer = true; |
238 | for (auto* tuple_field_dataflow_node : tuple_dataflow_node->inputs_) { |
239 | // Collect all the candidates which could produce this tuple field. |
240 | std::vector<CandidatePartition> to_appends; |
241 | size_t start_j = |
242 | i < ctxt->candidate_set->first_new_index() ? ctxt->candidate_set->first_new_index() : 0; |
243 | for (size_t j : exit_map[tuple_field_dataflow_node->index_]) { |
244 | if (i == j) { |
245 | continue; |
246 | } |
247 | if (i < start_j) { |
248 | // We already explored the cross-product of candidates [0, first_new_index), so don't |
249 | // do it again. |
250 | continue; |
251 | } |
252 | CandidatePartition tuple_field_producer = ctxt->candidate_set->at(j); |
253 | // The tuple_field_producer candidate can provide this tuple field. |
254 | // eg concatenate((..., producer, ...)) |
255 | to_appends.emplace_back(tuple_field_producer); |
256 | } |
257 | if (to_appends.empty()) { |
258 | // At least one of the tuple's fields does not have a producer candidate we can |
259 | // union in, so we need to give up. |
260 | all_tuple_fields_have_producer = false; |
261 | break; |
262 | } else { |
263 | // If to_appends = [A, B] and we already have possible unions [C, D] and [E, F] then |
264 | // the new possible unions are [C, D, A], [C, D, B], [E, F, A] and [E, F, B]. |
265 | std::vector<std::vector<CandidatePartition>> new_all_possible_unions; |
266 | for (const auto& to_append : to_appends) { |
267 | for (const auto& possible_union : all_possible_unions) { |
268 | new_all_possible_unions.emplace_back(possible_union); |
269 | new_all_possible_unions.back().emplace_back(to_append); |
270 | } |
271 | } |
272 | all_possible_unions = std::move(new_all_possible_unions); |
273 | } |
274 | } |
275 | |
276 | if (!all_tuple_fields_have_producer) { |
277 | continue; |
278 | } |
279 | |
280 | // Actually build the candidates which union according to all_possible_unions. |
281 | for (const auto& possible_union : all_possible_unions) { |
282 | if (possible_union.size() > 2) { |
283 | CandidatePartition new_candidate = |
284 | CandidatePartition::DisjointUnion(*ctxt->dataflow_graph, possible_union); |
285 | #if TVM_LOG_DEBUG |
286 | std::ostringstream os; |
287 | bool first = true; |
288 | for (const auto& candidate : possible_union) { |
289 | if (first) { |
290 | first = false; |
291 | } else { |
292 | os << ", " ; |
293 | } |
294 | os << candidate->ToString(); |
295 | } |
296 | VLOG(2) << "Fired rule " << rule_name_ << " on {" << os.str() << "} to yield " |
297 | << new_candidate->ToString(); |
298 | #endif |
299 | ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate); |
300 | } |
301 | } |
302 | } |
303 | } |
304 | } |
305 | |
306 | std::string TupleArgCombinerRuleNode::ToString() const { |
307 | return "TupleArgCombinerRule(" + rule_name_ + ")" ; |
308 | } |
309 | |
310 | TupleArgCombinerRule::TupleArgCombinerRule(String rule_name) { |
311 | auto node = runtime::make_object<TupleArgCombinerRuleNode>(); |
312 | node->rule_name_ = std::move(rule_name); |
313 | data_ = std::move(node); |
314 | } |
315 | |
316 | TVM_REGISTER_NODE_TYPE(TupleProjCombinerRuleNode); |
317 | |
318 | void TupleProjCombinerRuleNode::VisitAttrs(AttrVisitor* v) { |
319 | // TODO(mbs) |
320 | } |
321 | |
322 | void TupleProjCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const { |
323 | VLOG(1) << "running TupleProjCombinerRule(" << rule_name_ << ")" ; |
324 | // We already explored [0, first_new_index), so don't do it again. |
325 | for (size_t i = ctxt->candidate_set->first_new_index(); i < ctxt->candidate_set->size(); ++i) { |
326 | CandidatePartition base = ctxt->candidate_set->at(i); |
327 | for (PostDfsIndex index : base->sub_graph_->output_) { |
328 | auto node = ctxt->dataflow_graph->index_to_node(index); |
329 | if (node->ref().as<TupleGetItemNode>()) { |
330 | IndexSet index_set(ctxt->dataflow_graph->size(), {node->index_}); |
331 | SubGraph sub_graph(*ctxt->dataflow_graph, std::move(index_set), kInjective, "proj" ); |
332 | CandidatePartition proj_candidate("" , std::move(sub_graph), base->spec_); |
333 | CandidatePartition new_candidate = |
334 | base.DisjointUnion(*ctxt->dataflow_graph, proj_candidate); |
335 | VLOG(2) << "Fired rule " << rule_name_ << " on " << proj_candidate->ToString() << " and " |
336 | << base->ToString() << " to yield " << new_candidate->ToString(); |
337 | ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate); |
338 | } |
339 | } |
340 | } |
341 | } |
342 | |
343 | std::string TupleProjCombinerRuleNode::ToString() const { |
344 | return "TupleProjCombinerRule(" + rule_name_ + ")" ; |
345 | } |
346 | |
347 | TupleProjCombinerRule::TupleProjCombinerRule(String rule_name) { |
348 | auto node = runtime::make_object<TupleProjCombinerRuleNode>(); |
349 | node->rule_name_ = std::move(rule_name); |
350 | data_ = std::move(node); |
351 | } |
352 | |
353 | TVM_REGISTER_NODE_TYPE(ConstantCombinerRuleNode); |
354 | |
355 | void ConstantCombinerRuleNode::VisitAttrs(AttrVisitor* v) { |
356 | // TODO(mbs) |
357 | } |
358 | |
359 | void ConstantCombinerRuleNode::AppendAllResults(AppendAllResultsContext* ctxt) const { |
360 | VLOG(1) << "running ConstantCombinerRule(" << rule_name_ << ")" ; |
361 | // We already explored [0, first_new_index), so don't do it again. |
362 | for (size_t i = ctxt->candidate_set->first_new_index(); i < ctxt->candidate_set->size(); ++i) { |
363 | CandidatePartition base = ctxt->candidate_set->at(i); |
364 | IndexSet new_constants(ctxt->dataflow_graph->size()); |
365 | for (PostDfsIndex index : base->sub_graph_->input_) { |
366 | auto node = ctxt->dataflow_graph->index_to_node(index); |
367 | if (node->ref().as<ConstantNode>()) { |
368 | new_constants.Add(index); |
369 | } |
370 | } |
371 | if (!new_constants.IsZero()) { |
372 | SubGraph sub_graph(*ctxt->dataflow_graph, new_constants, kElemWise, "const" ); |
373 | CandidatePartition new_const_candidate("" , std::move(sub_graph), base->spec_); |
374 | CandidatePartition new_candidate = |
375 | base.DisjointUnion(*ctxt->dataflow_graph, new_const_candidate); |
376 | VLOG(2) << "Fired rule " << rule_name_ << " on " << new_const_candidate->ToString() << " and " |
377 | << base->ToString() << " to yield " << new_candidate->ToString(); |
378 | ctxt->candidate_set->Add(*ctxt->dataflow_graph, new_candidate); |
379 | } |
380 | } |
381 | } |
382 | |
383 | std::string ConstantCombinerRuleNode::ToString() const { |
384 | return "ConstantCombinerRule(" + rule_name_ + ")" ; |
385 | } |
386 | |
387 | ConstantCombinerRule::ConstantCombinerRule(String rule_name) { |
388 | auto node = runtime::make_object<ConstantCombinerRuleNode>(); |
389 | node->rule_name_ = std::move(rule_name); |
390 | data_ = std::move(node); |
391 | } |
392 | |
393 | } // namespace collage |
394 | } // namespace relay |
395 | } // namespace tvm |
396 | |