1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/partition_rule.h |
22 | * \brief Compositional partitioning rules. |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_COLLAGE_PARTITION_RULE_H_ |
26 | #define TVM_RELAY_COLLAGE_PARTITION_RULE_H_ |
27 | |
28 | #include <tvm/relay/dataflow_pattern.h> |
29 | #include <tvm/relay/expr.h> |
30 | |
31 | #include <string> |
32 | #include <vector> |
33 | |
34 | #include "../printer/doc.h" |
35 | #include "./candidate_partition.h" |
36 | #include "./combiner_rule.h" |
37 | #include "./sub_graph.h" |
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | namespace collage { |
42 | |
43 | /*! |
44 | * \brief Type of function to check if a matched sub-expression should be accepted by a rule. This |
45 | * can be used to, eg, reject operators of unsupported shape or dtype, or otherwise implement rules |
46 | * which are difficult to express in the dataflow pattern language directly. |
47 | */ |
48 | using TPatternPredicate = TypedPackedFunc<bool(const Expr& matched_sub_expr)>; |
49 | |
50 | /*! |
51 | * \brief The default pattern predicate. Always returns true. |
52 | */ |
53 | bool DefaultPatternPredicate(const Expr& matched_sub_expr); |
54 | |
55 | /*! |
56 | * \brief Base class of all partition rules. |
57 | * |
58 | * A \p PartitionRule describes how to find a set of \p CandidatePartitions for a \p DataflowGraph. |
59 | * The candidates are allowed to overlap, and ultimately it is the job of the Collage searcher to |
60 | * find a selection of candidates which covers the whole Relay expression without overlap. Partition |
61 | * rules are paired with their \p Target and other 'top level' configuration in a \p PartitionSpec. |
62 | * |
63 | * We provide a set of 'base' partition rules which produce candidates from the dataflow graph |
64 | * directly. We also provide a set of 'combinator' partition rules which can produce new candidates |
65 | * from the results of an arbitrary sub-rule or sub-rules. By mixing these base and combinator |
66 | * rules we can express a wide variety of partition strategies and encoding conventions. |
67 | * |
68 | * There may be many thousands of candidates in flight during the Collage search. We take care to |
69 | * defer constructing or rewriting Relay expressions until absolutely necessary. We only pay for |
 * extracting a function to represent a candidate when we need to measure its cost. And we only
71 | * pay for rewriting the overall Relay expression to commit to a partitioning when the Collage |
72 | * search has completed. |
73 | * |
74 | * The base rules implemented so far: |
75 | * - \p DFPatternPartitionRule: Given a \p DFPattern and expression predicate, produces a candidate |
76 | * for every sub-graph matched by the pattern and predicate. Unlike the \p PatternRewriter, |
77 | * candidates are free to overlap. Used to bring BYOC patterns into the Collage framework. |
78 | * - \p OpCallByKindPartitionRule: Uses the "TOpPattern" attribute provided for every Relay |
79 | * operator to produce a candidate for every call to a 'fusable Relay operator'. Used to |
80 | * look ahead to how TVM will fuse sub-graphs. |
81 | * |
82 | * The combinator rules implemented so far: |
83 | * - \p CompositePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped |
84 | * by a "Composite" function. The "Composite" name is taken from the rule name. Used to indicate |
85 | * Relay operators (or groups of Relay operators) should be mapped to target-specific operators, |
86 | * both for BYOC and TVM external library integrations. |
87 | * - \p PrimitivePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped |
88 | * by a "Primitive" function, possibly with an additional "Compiler" attribute. Used to |
89 | * delineate a partition (or kernel). |
90 | * - \p UnionPartitionRule: Simply unions all the candidates from all sub-rules together. Used to |
91 | * combine individual \p DFPatternPartitionRules. |
92 | * - \p CombinePartitionRule: Given a sub-rule and a list of 'combiner' rules, finds |
93 | * all possible ways of combining the sub-rule's candidates to yield even larger candidates. |
94 | * Note that the sub-rule's candidates may also be directly included in the results. The |
95 | * 'combiner' rules allow combining by \p OpPatternKinds, combining the arguments to tuples |
96 | * which themselves are arguments to Relay operator calls, and so on. This rule is intended to |
97 | * mimic the existing TVM \p FuseOps pass, though: |
98 | * i) all candidates are found rather than just the largest, ii) the starting set of candidates |
99 | * can be provided by any other rule, and iii) we rely on \p SubGraph validity checking to weed |
100 | * out infeasible candidates. |
101 | * - \p OnlyValidPartitionRule: Given a \p SubGraphConfig, ignores candidates with 'invalid' |
102 | * sub-graphs. Used to limit the maximum candidate depth, the number of independent outputs, |
103 | * and whether intermediate 'taps' are allowed. |
104 | * - \p HostPartitionRule: Produces candidates for all Relay expressions which could be |
105 | * 'left behind' for execution by the host (eg on the VM). This rule lets us simplify the |
106 | * overall Collage search algorithm. |
107 | * |
 * (Though not yet implemented, we'd like to allow a combinator rule which will union candidates
 * based on their 'anchor' operators. This can be used to implement 'vertical' and 'horizontal'
 * partition on more primitive candidates. Note that the \p SubGraph machinery supports
 * multiple-input and -output sub-graphs and their validation, so horizontal partition is easy to
 * implement.)
113 | * |
114 | * Here are some typical ways to combine \p PartitionRules for different partition/fusion |
115 | * strategies: |
116 | * |
117 | * - Classic pattern-based BYOC with \p MergeComposite/AnnotateTarget/PartitionGraph passes: |
118 | * \code |
119 | * PrimitivePartitionRule |
120 | * OnlyValidPartitionRule |
121 | * CombinePartitionRule (with join-anything combiner rule) |
122 | * UnionPartitionRule |
123 | * CompositePartitionRule(label1) |
124 | * DFPatternPartitionRule(pattern1) |
125 | * : |
126 | * CompositePartitionRule(labeln) |
127 | * DFPatternPartitionRule(patternn) |
128 | * \endcode |
129 | * |
130 | * - "Consider this library implementation for these sub-expressions", using \p DFPatterns to |
131 | * pick out which Relay operators are supported: |
132 | * \code |
133 | * OnlyValidPartitionRule |
134 | * CombinePartitionRule (with default TVM combiner rules) |
135 | * UnionPartitionRule |
136 | * OpCallByKindPartitionRule |
 * CompositePartitionRule(label1)
 * DFPatternPartitionRule(pattern1)
 * :
 * CompositePartitionRule(labeln)
141 | * DFPatternPartitionRule(patternn) |
142 | * \endcode |
143 | * |
144 | * - Classic TVM \p FuseOps |
145 | * \code |
146 | * PrimitivePartitionRule |
147 | * OnlyValidPartitionRule |
148 | * CombinePartitionRule (with default TVM combiner rules) |
149 | * OpCallByKindPartitionRule |
150 | * \endcode |
151 | * |
152 | * - "Just fuse what I tell you to fuse", using \p DFPatterns to directly select candidates: |
153 | * \code |
154 | * PrimitivePartitionRule |
155 | * OnlyValidPartitionRule |
156 | * UnionPartitionRule |
157 | * DFPatternPartitionRule(pattern1) |
158 | * : |
159 | * DFPatternPartitionRule(patternn) |
160 | * \endcode |
161 | */ |
162 | class PartitionRuleNode : public Object { |
163 | public: |
164 | /*! |
165 | * \brief A unique (over all rules for the same target) name for the rule. Rule names are |
166 | * combined and captured with \p PartitionCandidate rule names for debuggability and |
167 | * explainability. Some rules will copy the rule name into function attributes. |
168 | * |
169 | */ |
170 | String rule_name_; |
171 | |
172 | void VisitAttrs(AttrVisitor* v); |
173 | |
174 | /*! |
175 | * \brief Returns all the possible candidate partitions according to this rule for the overall |
176 | * expression corresponding to \p dataflow_graph. The candidates will generally have unknown |
177 | * target and cost: the target will be filled in by the \p PartitionSpec, while the cost will |
178 | * be filled in lazily. |
179 | */ |
180 | virtual std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
181 | const PartitionSpec& spec) const; |
182 | |
183 | std::string ToString() const; |
184 | Doc ToDoc() const; |
185 | |
186 | protected: |
187 | virtual void AppendBodyItems(std::vector<Doc>* body_items) const; |
188 | |
189 | public: |
190 | static constexpr const char* _type_key = "relay.collage.PartitionRule" ; |
191 | static constexpr const uint32_t _type_child_slots = 10; |
192 | TVM_DECLARE_BASE_OBJECT_INFO(PartitionRuleNode, Object); |
193 | }; |
194 | |
195 | class PartitionRule : public ObjectRef { |
196 | public: |
197 | explicit PartitionRule(String rule_name); |
198 | |
199 | TVM_DEFINE_OBJECT_REF_METHODS(PartitionRule, ObjectRef, PartitionRuleNode); |
200 | }; |
201 | |
202 | /*! |
203 | * \brief Partition rule which fires on all sub-expressions matching a dataflow-pattern and pattern |
204 | * predicate. It is valid for matching candidates to overlap. |
205 | */ |
206 | class DFPatternPartitionRuleNode : public PartitionRuleNode { |
207 | public: |
208 | /*! |
209 | * \brief Relay pattern. |
210 | */ |
211 | DFPattern pattern_; |
212 | |
213 | /*! |
214 | * \brief Predicate on matched sub-expression to decide if partition rule should fire. |
215 | */ |
216 | TPatternPredicate predicate_; |
217 | |
218 | void VisitAttrs(AttrVisitor* v); |
219 | |
220 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
221 | const PartitionSpec& spec) const override; |
222 | |
223 | void AppendBodyItems(std::vector<Doc>* body_items) const override; |
224 | |
225 | static constexpr const char* _type_key = "relay.collage.DFPatternPartitionRule" ; |
226 | TVM_DECLARE_FINAL_OBJECT_INFO(DFPatternPartitionRuleNode, PartitionRuleNode); |
227 | }; |
228 | |
229 | class DFPatternPartitionRule : public PartitionRule { |
230 | public: |
231 | DFPatternPartitionRule(String rule_name, DFPattern pattern, |
232 | TPatternPredicate predicate = DefaultPatternPredicate); |
233 | |
234 | TVM_DEFINE_OBJECT_REF_METHODS(DFPatternPartitionRule, PartitionRule, DFPatternPartitionRuleNode); |
235 | }; |
236 | |
237 | /*! |
238 | * \brief Partition rule which wraps candidates within a function with the "Composite" attribute |
239 | * bound to the given rule name. |
240 | * |
241 | * This is the standard way by which operators or operator groups are tagged as being supported |
242 | * by a particular externally provided function. It is up to the BYOC lowering function to |
243 | * recognize the "Composite" name and emit the appropriate code or call. |
244 | */ |
245 | class CompositePartitionRuleNode : public PartitionRuleNode { |
246 | public: |
247 | /*! \brief The sub-partition rule. */ |
248 | PartitionRule sub_rule_; |
249 | |
250 | void VisitAttrs(AttrVisitor* v); |
251 | |
252 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
253 | const PartitionSpec& spec) const override; |
254 | |
255 | void AppendBodyItems(std::vector<Doc>* body_items) const override; |
256 | |
257 | static constexpr const char* _type_key = "relay.collage.CompositePartitionRule" ; |
258 | TVM_DECLARE_FINAL_OBJECT_INFO(CompositePartitionRuleNode, PartitionRuleNode); |
259 | }; |
260 | |
261 | class CompositePartitionRule : public PartitionRule { |
262 | public: |
263 | CompositePartitionRule(String rule_name, PartitionRule sub_rule); |
264 | |
265 | TVM_DEFINE_OBJECT_REF_METHODS(CompositePartitionRule, PartitionRule, CompositePartitionRuleNode); |
266 | }; |
267 | |
268 | /*! |
269 | * \brief Partition rule which wraps candidates within a function with the "Primitive" attribute |
270 | * bound to 1. If the partition spec target(s) have the "compiler" attribute then that name is |
271 | * also added to the function as a "Compiler" attribute. |
272 | * |
273 | * This is the standard way by which sub-graphs are marked as being in a 'partition' who's |
274 | * compilation will be managed by an external BYOC toolchain. It can also be used to mark |
275 | * sub-graphs for lowering to a single kernel by the built-in TVM lowering machinery. |
276 | */ |
277 | class PrimitivePartitionRuleNode : public PartitionRuleNode { |
278 | public: |
279 | /*! \brief The sub-partition rule. */ |
280 | PartitionRule sub_rule_; |
281 | |
282 | void VisitAttrs(AttrVisitor* v); |
283 | |
284 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
285 | const PartitionSpec& spec) const override; |
286 | |
287 | void AppendBodyItems(std::vector<Doc>* body_items) const override; |
288 | |
289 | static constexpr const char* _type_key = "relay.collage.PrimitivePartitionRule" ; |
290 | TVM_DECLARE_FINAL_OBJECT_INFO(PrimitivePartitionRuleNode, PartitionRuleNode); |
291 | }; |
292 | |
293 | class PrimitivePartitionRule : public PartitionRule { |
294 | public: |
295 | PrimitivePartitionRule(String rule_name, PartitionRule sub_rule); |
296 | |
297 | TVM_DEFINE_OBJECT_REF_METHODS(PrimitivePartitionRule, PartitionRule, PrimitivePartitionRuleNode); |
298 | }; |
299 | |
300 | /*! |
301 | * \brief Partition rule which simply unions all matches from all sub-partition rules. |
302 | * |
303 | * This can be used to combine the results of a set of, eg, DFPatternPartitionRules. |
304 | */ |
305 | class UnionPartitionRuleNode : public PartitionRuleNode { |
306 | public: |
307 | Array<PartitionRule> sub_rules_; |
308 | |
309 | void VisitAttrs(AttrVisitor* v); |
310 | |
311 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
312 | const PartitionSpec& spec) const override; |
313 | |
314 | void AppendBodyItems(std::vector<Doc>* body_items) const override; |
315 | |
316 | static constexpr const char* _type_key = "relay.collage.UnionPartitionRule" ; |
317 | TVM_DECLARE_FINAL_OBJECT_INFO(UnionPartitionRuleNode, PartitionRuleNode); |
318 | }; |
319 | |
320 | class UnionPartitionRule : public PartitionRule { |
321 | public: |
322 | UnionPartitionRule(String rule_name, Array<PartitionRule> sub_rules); |
323 | |
324 | TVM_DEFINE_OBJECT_REF_METHODS(UnionPartitionRule, PartitionRule, UnionPartitionRuleNode) |
325 | }; |
326 | |
327 | /* |
328 | *! \brief Partition rule which places calls to Relay operators with a "TOpPattern" attribute of |
329 | * \p kOutEWiseFusable or less in their own singleton sub-graph. No other Relay sub-expressions |
330 | * (such as tuples or tuple projection) are selected, and it is up to outer partition rules to |
331 | * account for them. |
332 | */ |
333 | class OpCallByKindPartitionRuleNode : public PartitionRuleNode { |
334 | public: |
335 | void VisitAttrs(AttrVisitor* v); |
336 | |
337 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
338 | const PartitionSpec& spec) const override; |
339 | |
340 | void AppendBodyItems(std::vector<Doc>* body_items) const override; |
341 | |
342 | static constexpr const char* _type_key = "relay.collage.OpCallByKindPartitionRule" ; |
343 | TVM_DECLARE_FINAL_OBJECT_INFO(OpCallByKindPartitionRuleNode, PartitionRuleNode); |
344 | }; |
345 | |
346 | class OpCallByKindPartitionRule : public PartitionRule { |
347 | public: |
348 | explicit OpCallByKindPartitionRule(String rule_name); |
349 | |
350 | TVM_DEFINE_OBJECT_REF_METHODS(OpCallByKindPartitionRule, PartitionRule, |
351 | OpCallByKindPartitionRuleNode); |
352 | }; |
353 | |
354 | /*! |
355 | * \brief Partition rule which combines sub-graphs to exploit optimizations commonly available in |
356 | * backends (including the TVM lowering backend). Those optimization rules are in turn described by |
357 | * one or more primitive \p CombinerRules. |
358 | * |
359 | * For TVM these primitive combiner rules are guided by the \p OpPatternKind associated with every |
360 | * sub-graph. That in turn is the maximum of the kind of each expression node in the sub-graph, |
361 | * using the rules: |
 * - Constants are \p kElemWise.
363 | * - A call to a Relay operator has the kind of its callee. |
364 | * - Tuple construction and projection are injective provided all tuple fields are of tensor type. |
365 | * - All other sub-expressions are opaque. |
366 | * |
367 | * The available \p OpPatternKinds (and our abbreviations for them) are: |
368 | * - E: kElemWise, eg nn.relu |
369 | * - B: kBroadcast, eg add |
370 | * - I: kInjective, eg concatenate |
371 | * - R: kCommReduce, eg sum |
372 | * - A: kOutEWiseFusable, eg nn.conv2d (often called 'anchor nodes', hence the A abbreviation) |
373 | * - O: kOpaque, everything else |
374 | * (The kTuple kind is not used by this machinery.) |
375 | * |
376 | * Kinds are ordered as above from least- to most-constraining w.r.t. possible partition |
377 | * opportunities. When we write a kind abbreviation below we intend it to mean that kind *or less*. |
 * And when we write 'kl -> kr' we mean it to match a sub-expression of kind kr or less whose
379 | * dataflow inputs are all of kind kl or less. |
380 | * |
381 | * We can then mimic the classic \p FuseOps TVM Pass with the following more primitive combiner |
382 | * rules: |
383 | * - Sub-groups cannot have taps. In the classic \p FuseOps pass taps are avoided by construction |
384 | * by always considering all node->dominator paths. Here we naively allow taps on all candidates, |
385 | * but reject them using SubGraph::IsValid with a SubGraphConfig with allow_taps = false. |
386 | * - Combine A -> B |
387 | * - Combine B -> R |
388 | * - Combine I -> I |
389 | * - Combine I -> tuple -> I. That is, if an I sub-graph has a tuple as input, and at least one |
390 | * tuple field can be provided by an I sub-graph exit, then both the tuple and all such fields |
391 | * may be joined. |
 *
393 | * Note that \p FuseOps only considers the largest possible sub-graphs. However this partition rule |
394 | * considers all possibilities so as to 'make room' for other targets supplying other |
395 | * overlapping candidates. |
396 | * |
397 | * See combiner_rule.h for the more primitive combiner rules which implement the above. |
398 | */ |
399 | class CombinePartitionRuleNode : public PartitionRuleNode { |
400 | public: |
401 | /*! \brief The sub-rule supplying the initial set of candidates. */ |
402 | PartitionRule sub_rule_; |
403 | /*! \brief The more primitive rules to use to combine the candidates found by the above rule. */ |
404 | Array<CombinerRule> combiner_rules_; |
405 | /*! \brief Maximum max_depth for candidates. */ |
406 | size_t max_depth_; |
407 | |
408 | void VisitAttrs(AttrVisitor* v); |
409 | |
410 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
411 | const PartitionSpec& spec) const override; |
412 | |
413 | void AppendBodyItems(std::vector<Doc>* body_items) const override; |
414 | |
415 | public: |
416 | static constexpr const char* _type_key = "relay.collage.CombinePartitionRule" ; |
417 | TVM_DECLARE_FINAL_OBJECT_INFO(CombinePartitionRuleNode, PartitionRuleNode); |
418 | }; |
419 | |
420 | class CombinePartitionRule : public PartitionRule { |
421 | public: |
422 | CombinePartitionRule(String rule_name, PartitionRule sub_rule, Array<CombinerRule> combiner_rules, |
423 | size_t max_depth_); |
424 | |
425 | TVM_DEFINE_OBJECT_REF_METHODS(CombinePartitionRule, PartitionRule, CombinePartitionRuleNode); |
426 | }; |
427 | |
428 | /*! |
429 | * \brief Partition rules which keeps only candidates from the sub-rule whose sub-groups are valid |
430 | * w.r.t. the given \p SubGraphConfig. |
431 | */ |
432 | class OnlyValidPartitionRuleNode : public PartitionRuleNode { |
433 | public: |
434 | PartitionRule sub_rule_; |
435 | SubGraphConfig config_; |
436 | |
437 | void VisitAttrs(AttrVisitor* v); |
438 | |
439 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
440 | const PartitionSpec& spec) const override; |
441 | |
442 | void AppendBodyItems(std::vector<Doc>* body_items) const override; |
443 | |
444 | public: |
445 | static constexpr const char* _type_key = "relay.collage.OnlyValidPartitionRule" ; |
446 | TVM_DECLARE_FINAL_OBJECT_INFO(OnlyValidPartitionRuleNode, PartitionRuleNode); |
447 | }; |
448 | |
449 | class OnlyValidPartitionRule : public PartitionRule { |
450 | public: |
451 | OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, const SubGraphConfig& config); |
452 | |
453 | TVM_DEFINE_OBJECT_REF_METHODS(OnlyValidPartitionRule, PartitionRule, OnlyValidPartitionRuleNode); |
454 | }; |
455 | |
456 | /*! |
457 | * \brief Partition rule which selects nodes which can be 'left behind' to be executed by the host |
458 | * (eg on the VM). This includes most of the 'interstitial' Relay constructs, such a let bindings, |
459 | * operators on references, calls to non-operator functions, and so on. It can also include the |
460 | * construction of and projection from tuples which may not be supported within a partition. |
461 | */ |
462 | class HostPartitionRuleNode : public PartitionRuleNode { |
463 | public: |
464 | void VisitAttrs(AttrVisitor* v); |
465 | |
466 | std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph, |
467 | const PartitionSpec& spec) const override; |
468 | |
469 | void AppendBodyItems(std::vector<Doc>* body_items) const override; |
470 | |
471 | public: |
472 | static constexpr const char* _type_key = "relay.collage.HostPartitionRule" ; |
473 | TVM_DECLARE_FINAL_OBJECT_INFO(HostPartitionRuleNode, PartitionRuleNode); |
474 | }; |
475 | |
476 | class HostPartitionRule : public PartitionRule { |
477 | public: |
478 | explicit HostPartitionRule(String rule_name); |
479 | |
480 | TVM_DEFINE_OBJECT_REF_METHODS(HostPartitionRule, PartitionRule, HostPartitionRuleNode); |
481 | }; |
482 | |
483 | } // namespace collage |
484 | } // namespace relay |
485 | } // namespace tvm |
486 | |
487 | #endif // TVM_RELAY_COLLAGE_PARTITION_RULE_H_ |
488 | |