1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/partition_rule.h
22 * \brief Compositional partitioning rules.
23 */
24
25#ifndef TVM_RELAY_COLLAGE_PARTITION_RULE_H_
26#define TVM_RELAY_COLLAGE_PARTITION_RULE_H_
27
28#include <tvm/relay/dataflow_pattern.h>
29#include <tvm/relay/expr.h>
30
31#include <string>
32#include <vector>
33
34#include "../printer/doc.h"
35#include "./candidate_partition.h"
36#include "./combiner_rule.h"
37#include "./sub_graph.h"
38
39namespace tvm {
40namespace relay {
41namespace collage {
42
/*!
 * \brief Type of function to check if a matched sub-expression should be accepted by a rule. This
 * can be used to, eg, reject operators of unsupported shape or dtype, or otherwise implement rules
 * which are difficult to express in the dataflow pattern language directly.
 *
 * The function is passed the matched sub-expression and returns a bool indicating whether the
 * match should be accepted.
 */
using TPatternPredicate = TypedPackedFunc<bool(const Expr& matched_sub_expr)>;
49
/*!
 * \brief The default pattern predicate. Always returns true.
 *
 * This is the default value of the \p predicate argument of the \p DFPatternPartitionRule
 * constructor, so by default every pattern match is accepted.
 *
 * \param matched_sub_expr The sub-expression matched by the pattern.
 * \return Always true.
 */
bool DefaultPatternPredicate(const Expr& matched_sub_expr);
54
55/*!
56 * \brief Base class of all partition rules.
57 *
58 * A \p PartitionRule describes how to find a set of \p CandidatePartitions for a \p DataflowGraph.
59 * The candidates are allowed to overlap, and ultimately it is the job of the Collage searcher to
60 * find a selection of candidates which covers the whole Relay expression without overlap. Partition
61 * rules are paired with their \p Target and other 'top level' configuration in a \p PartitionSpec.
62 *
63 * We provide a set of 'base' partition rules which produce candidates from the dataflow graph
64 * directly. We also provide a set of 'combinator' partition rules which can produce new candidates
65 * from the results of an arbitrary sub-rule or sub-rules. By mixing these base and combinator
66 * rules we can express a wide variety of partition strategies and encoding conventions.
67 *
68 * There may be many thousands of candidates in flight during the Collage search. We take care to
69 * defer constructing or rewriting Relay expressions until absolutely necessary. We only pay for
 * extracting a function to represent a candidate when we need to measure its cost. And we only
71 * pay for rewriting the overall Relay expression to commit to a partitioning when the Collage
72 * search has completed.
73 *
74 * The base rules implemented so far:
75 * - \p DFPatternPartitionRule: Given a \p DFPattern and expression predicate, produces a candidate
76 * for every sub-graph matched by the pattern and predicate. Unlike the \p PatternRewriter,
77 * candidates are free to overlap. Used to bring BYOC patterns into the Collage framework.
78 * - \p OpCallByKindPartitionRule: Uses the "TOpPattern" attribute provided for every Relay
79 * operator to produce a candidate for every call to a 'fusable Relay operator'. Used to
80 * look ahead to how TVM will fuse sub-graphs.
81 *
82 * The combinator rules implemented so far:
83 * - \p CompositePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped
84 * by a "Composite" function. The "Composite" name is taken from the rule name. Used to indicate
85 * Relay operators (or groups of Relay operators) should be mapped to target-specific operators,
86 * both for BYOC and TVM external library integrations.
87 * - \p PrimitivePartitionRule: Indicates all candidates matched by the sub-rule should be wrapped
88 * by a "Primitive" function, possibly with an additional "Compiler" attribute. Used to
89 * delineate a partition (or kernel).
90 * - \p UnionPartitionRule: Simply unions all the candidates from all sub-rules together. Used to
91 * combine individual \p DFPatternPartitionRules.
92 * - \p CombinePartitionRule: Given a sub-rule and a list of 'combiner' rules, finds
93 * all possible ways of combining the sub-rule's candidates to yield even larger candidates.
94 * Note that the sub-rule's candidates may also be directly included in the results. The
95 * 'combiner' rules allow combining by \p OpPatternKinds, combining the arguments to tuples
96 * which themselves are arguments to Relay operator calls, and so on. This rule is intended to
97 * mimic the existing TVM \p FuseOps pass, though:
98 * i) all candidates are found rather than just the largest, ii) the starting set of candidates
99 * can be provided by any other rule, and iii) we rely on \p SubGraph validity checking to weed
100 * out infeasible candidates.
101 * - \p OnlyValidPartitionRule: Given a \p SubGraphConfig, ignores candidates with 'invalid'
102 * sub-graphs. Used to limit the maximum candidate depth, the number of independent outputs,
103 * and whether intermediate 'taps' are allowed.
104 * - \p HostPartitionRule: Produces candidates for all Relay expressions which could be
105 * 'left behind' for execution by the host (eg on the VM). This rule lets us simplify the
106 * overall Collage search algorithm.
107 *
 * (Though not yet implemented, we'd like to allow a combinator rule which will union candidates
 * based on their 'anchor' operators. This can be used to implement 'vertical' and 'horizontal'
 * partition on more primitive candidates. Note that the \p SubGraph machinery supports
 * multiple-input and -output sub-graphs and their validation, so horizontal partition is easy to
 * implement.)
113 *
114 * Here are some typical ways to combine \p PartitionRules for different partition/fusion
115 * strategies:
116 *
117 * - Classic pattern-based BYOC with \p MergeComposite/AnnotateTarget/PartitionGraph passes:
 * \code
 *   PrimitivePartitionRule
 *     OnlyValidPartitionRule
 *       CombinePartitionRule (with join-anything combiner rule)
 *         UnionPartitionRule
 *           CompositePartitionRule(label1)
 *             DFPatternPartitionRule(pattern1)
 *                         :
 *           CompositePartitionRule(labeln)
 *             DFPatternPartitionRule(patternn)
 * \endcode
129 *
130 * - "Consider this library implementation for these sub-expressions", using \p DFPatterns to
131 * pick out which Relay operators are supported:
 * \code
 *   OnlyValidPartitionRule
 *     CombinePartitionRule (with default TVM combiner rules)
 *       UnionPartitionRule
 *         OpCallByKindPartitionRule
 *         CompositePartitionRule(label1)
 *           DFPatternPartitionRule(pattern1)
 *                       :
 *         CompositePartitionRule(labeln)
 *           DFPatternPartitionRule(patternn)
 * \endcode
143 *
144 * - Classic TVM \p FuseOps
 * \code
 *   PrimitivePartitionRule
 *     OnlyValidPartitionRule
 *       CombinePartitionRule (with default TVM combiner rules)
 *         OpCallByKindPartitionRule
 * \endcode
151 *
152 * - "Just fuse what I tell you to fuse", using \p DFPatterns to directly select candidates:
 * \code
 *   PrimitivePartitionRule
 *     OnlyValidPartitionRule
 *       UnionPartitionRule
 *         DFPatternPartitionRule(pattern1)
 *                     :
 *         DFPatternPartitionRule(patternn)
 * \endcode
161 */
class PartitionRuleNode : public Object {
 public:
  /*!
   * \brief A unique (over all rules for the same target) name for the rule. Rule names are
   * combined and captured with \p CandidatePartition rule names for debuggability and
   * explainability. Some rules will copy the rule name into function attributes.
   */
  String rule_name_;

  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Returns all the possible candidate partitions according to this rule for the overall
   * expression corresponding to \p dataflow_graph. The candidates will generally have unknown
   * target and cost: the target will be filled in by the \p PartitionSpec, while the cost will
   * be filled in lazily.
   *
   * Virtual: every concrete rule declared in this header overrides it.
   *
   * \param dataflow_graph The dataflow graph for the overall expression.
   * \param spec The partition spec this rule is being evaluated for.
   * \return The candidate partitions found by this rule.
   */
  virtual std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                        const PartitionSpec& spec) const;

  /*! \brief Returns a string rendering of this rule. Defined out of line. */
  std::string ToString() const;
  /*! \brief Returns a \p Doc rendering of this rule. Defined out of line. */
  Doc ToDoc() const;

 protected:
  /*!
   * \brief Hook which sub-classes override to append their own entries to \p body_items.
   *
   * NOTE(review): presumably the items make up the body rendered by \p ToDoc -- confirm against
   * the implementation file.
   */
  virtual void AppendBodyItems(std::vector<Doc>* body_items) const;

 public:
  static constexpr const char* _type_key = "relay.collage.PartitionRule";
  // NOTE(review): presumably the number of type-index slots reserved for sub-classes by the TVM
  // object system -- confirm against the TVM Object documentation before adding sub-classes.
  static constexpr const uint32_t _type_child_slots = 10;
  TVM_DECLARE_BASE_OBJECT_INFO(PartitionRuleNode, Object);
};
194
/*!
 * \brief Reference class for \p PartitionRuleNode, and base class of the reference classes for
 * all the concrete partition rules declared in this header.
 */
class PartitionRule : public ObjectRef {
 public:
  /*!
   * \brief Constructs a partition rule.
   * \param rule_name The name for the rule.
   */
  explicit PartitionRule(String rule_name);

  TVM_DEFINE_OBJECT_REF_METHODS(PartitionRule, ObjectRef, PartitionRuleNode);
};
201
/*!
 * \brief Partition rule which fires on all sub-expressions matching a dataflow-pattern and pattern
 * predicate. It is valid for matching candidates to overlap.
 */
class DFPatternPartitionRuleNode : public PartitionRuleNode {
 public:
  /*!
   * \brief Relay pattern.
   */
  DFPattern pattern_;

  /*!
   * \brief Predicate on matched sub-expression to decide if partition rule should fire.
   */
  TPatternPredicate predicate_;

  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Override of \p PartitionRuleNode::AllCandidates. See the class comment for which
   * candidates this rule produces.
   */
  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                const PartitionSpec& spec) const override;

  /*!
   * \brief Override of the base class hook for appending entries to \p body_items. Note the hook
   * is protected in \p PartitionRuleNode but public here.
   */
  void AppendBodyItems(std::vector<Doc>* body_items) const override;

  static constexpr const char* _type_key = "relay.collage.DFPatternPartitionRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(DFPatternPartitionRuleNode, PartitionRuleNode);
};
228
/*!
 * \brief Reference class for \p DFPatternPartitionRuleNode.
 */
class DFPatternPartitionRule : public PartitionRule {
 public:
  /*!
   * \brief Constructs a rule from a dataflow pattern and an optional predicate.
   * \param rule_name The name for the rule.
   * \param pattern The Relay dataflow pattern.
   * \param predicate Predicate on the matched sub-expression. Defaults to
   * \p DefaultPatternPredicate.
   */
  DFPatternPartitionRule(String rule_name, DFPattern pattern,
                         TPatternPredicate predicate = DefaultPatternPredicate);

  TVM_DEFINE_OBJECT_REF_METHODS(DFPatternPartitionRule, PartitionRule, DFPatternPartitionRuleNode);
};
236
/*!
 * \brief Partition rule which wraps candidates within a function with the "Composite" attribute
 * bound to the given rule name.
 *
 * This is the standard way by which operators or operator groups are tagged as being supported
 * by a particular externally provided function. It is up to the BYOC lowering function to
 * recognize the "Composite" name and emit the appropriate code or call.
 */
class CompositePartitionRuleNode : public PartitionRuleNode {
 public:
  /*! \brief The sub-partition rule. */
  PartitionRule sub_rule_;

  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Override of \p PartitionRuleNode::AllCandidates. See the class comment for which
   * candidates this rule produces.
   */
  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                const PartitionSpec& spec) const override;

  /*!
   * \brief Override of the base class hook for appending entries to \p body_items. Note the hook
   * is protected in \p PartitionRuleNode but public here.
   */
  void AppendBodyItems(std::vector<Doc>* body_items) const override;

  static constexpr const char* _type_key = "relay.collage.CompositePartitionRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(CompositePartitionRuleNode, PartitionRuleNode);
};
260
/*!
 * \brief Reference class for \p CompositePartitionRuleNode.
 */
class CompositePartitionRule : public PartitionRule {
 public:
  /*!
   * \brief Constructs a rule over the given sub-rule.
   * \param rule_name The name for the rule.
   * \param sub_rule The sub-partition rule.
   */
  CompositePartitionRule(String rule_name, PartitionRule sub_rule);

  TVM_DEFINE_OBJECT_REF_METHODS(CompositePartitionRule, PartitionRule, CompositePartitionRuleNode);
};
267
/*!
 * \brief Partition rule which wraps candidates within a function with the "Primitive" attribute
 * bound to 1. If the partition spec target(s) have the "compiler" attribute then that name is
 * also added to the function as a "Compiler" attribute.
 *
 * This is the standard way by which sub-graphs are marked as being in a 'partition' whose
 * compilation will be managed by an external BYOC toolchain. It can also be used to mark
 * sub-graphs for lowering to a single kernel by the built-in TVM lowering machinery.
 */
class PrimitivePartitionRuleNode : public PartitionRuleNode {
 public:
  /*! \brief The sub-partition rule. */
  PartitionRule sub_rule_;

  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Override of \p PartitionRuleNode::AllCandidates. See the class comment for which
   * candidates this rule produces.
   */
  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                const PartitionSpec& spec) const override;

  /*!
   * \brief Override of the base class hook for appending entries to \p body_items. Note the hook
   * is protected in \p PartitionRuleNode but public here.
   */
  void AppendBodyItems(std::vector<Doc>* body_items) const override;

  static constexpr const char* _type_key = "relay.collage.PrimitivePartitionRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(PrimitivePartitionRuleNode, PartitionRuleNode);
};
292
/*!
 * \brief Reference class for \p PrimitivePartitionRuleNode.
 */
class PrimitivePartitionRule : public PartitionRule {
 public:
  /*!
   * \brief Constructs a rule over the given sub-rule.
   * \param rule_name The name for the rule.
   * \param sub_rule The sub-partition rule.
   */
  PrimitivePartitionRule(String rule_name, PartitionRule sub_rule);

  TVM_DEFINE_OBJECT_REF_METHODS(PrimitivePartitionRule, PartitionRule, PrimitivePartitionRuleNode);
};
299
/*!
 * \brief Partition rule which simply unions all matches from all sub-partition rules.
 *
 * This can be used to combine the results of a set of, eg, DFPatternPartitionRules.
 */
class UnionPartitionRuleNode : public PartitionRuleNode {
 public:
  /*! \brief The sub-partition rules whose matches are unioned. */
  Array<PartitionRule> sub_rules_;

  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Override of \p PartitionRuleNode::AllCandidates. See the class comment for which
   * candidates this rule produces.
   */
  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                const PartitionSpec& spec) const override;

  /*!
   * \brief Override of the base class hook for appending entries to \p body_items. Note the hook
   * is protected in \p PartitionRuleNode but public here.
   */
  void AppendBodyItems(std::vector<Doc>* body_items) const override;

  static constexpr const char* _type_key = "relay.collage.UnionPartitionRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(UnionPartitionRuleNode, PartitionRuleNode);
};
319
320class UnionPartitionRule : public PartitionRule {
321 public:
322 UnionPartitionRule(String rule_name, Array<PartitionRule> sub_rules);
323
324 TVM_DEFINE_OBJECT_REF_METHODS(UnionPartitionRule, PartitionRule, UnionPartitionRuleNode)
325};
326
/*!
 * \brief Partition rule which places calls to Relay operators with a "TOpPattern" attribute of
 * \p kOutEWiseFusable or less in their own singleton sub-graph. No other Relay sub-expressions
 * (such as tuples or tuple projection) are selected, and it is up to outer partition rules to
 * account for them.
 */
class OpCallByKindPartitionRuleNode : public PartitionRuleNode {
 public:
  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Override of \p PartitionRuleNode::AllCandidates. See the class comment for which
   * candidates this rule produces.
   */
  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                const PartitionSpec& spec) const override;

  /*!
   * \brief Override of the base class hook for appending entries to \p body_items. Note the hook
   * is protected in \p PartitionRuleNode but public here.
   */
  void AppendBodyItems(std::vector<Doc>* body_items) const override;

  static constexpr const char* _type_key = "relay.collage.OpCallByKindPartitionRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(OpCallByKindPartitionRuleNode, PartitionRuleNode);
};
345
/*!
 * \brief Reference class for \p OpCallByKindPartitionRuleNode.
 */
class OpCallByKindPartitionRule : public PartitionRule {
 public:
  /*!
   * \brief Constructs the rule.
   * \param rule_name The name for the rule.
   */
  explicit OpCallByKindPartitionRule(String rule_name);

  TVM_DEFINE_OBJECT_REF_METHODS(OpCallByKindPartitionRule, PartitionRule,
                                OpCallByKindPartitionRuleNode);
};
353
354/*!
355 * \brief Partition rule which combines sub-graphs to exploit optimizations commonly available in
356 * backends (including the TVM lowering backend). Those optimization rules are in turn described by
357 * one or more primitive \p CombinerRules.
358 *
359 * For TVM these primitive combiner rules are guided by the \p OpPatternKind associated with every
360 * sub-graph. That in turn is the maximum of the kind of each expression node in the sub-graph,
361 * using the rules:
 * - Constants are \p kElemWise.
363 * - A call to a Relay operator has the kind of its callee.
364 * - Tuple construction and projection are injective provided all tuple fields are of tensor type.
365 * - All other sub-expressions are opaque.
366 *
367 * The available \p OpPatternKinds (and our abbreviations for them) are:
368 * - E: kElemWise, eg nn.relu
369 * - B: kBroadcast, eg add
370 * - I: kInjective, eg concatenate
371 * - R: kCommReduce, eg sum
372 * - A: kOutEWiseFusable, eg nn.conv2d (often called 'anchor nodes', hence the A abbreviation)
373 * - O: kOpaque, everything else
374 * (The kTuple kind is not used by this machinery.)
375 *
376 * Kinds are ordered as above from least- to most-constraining w.r.t. possible partition
377 * opportunities. When we write a kind abbreviation below we intend it to mean that kind *or less*.
 * And when we write 'kl -> kr' we mean it to match a sub-expression of kind kr or less whose
 * dataflow inputs are all of kind kl or less.
380 *
381 * We can then mimic the classic \p FuseOps TVM Pass with the following more primitive combiner
382 * rules:
383 * - Sub-groups cannot have taps. In the classic \p FuseOps pass taps are avoided by construction
384 * by always considering all node->dominator paths. Here we naively allow taps on all candidates,
385 * but reject them using SubGraph::IsValid with a SubGraphConfig with allow_taps = false.
386 * - Combine A -> B
387 * - Combine B -> R
388 * - Combine I -> I
389 * - Combine I -> tuple -> I. That is, if an I sub-graph has a tuple as input, and at least one
390 * tuple field can be provided by an I sub-graph exit, then both the tuple and all such fields
391 * may be joined.
 *
393 * Note that \p FuseOps only considers the largest possible sub-graphs. However this partition rule
394 * considers all possibilities so as to 'make room' for other targets supplying other
395 * overlapping candidates.
396 *
397 * See combiner_rule.h for the more primitive combiner rules which implement the above.
398 */
class CombinePartitionRuleNode : public PartitionRuleNode {
 public:
  /*! \brief The sub-rule supplying the initial set of candidates. */
  PartitionRule sub_rule_;
  /*! \brief The more primitive rules to use to combine the candidates found by the above rule. */
  Array<CombinerRule> combiner_rules_;
  /*! \brief Maximum depth for candidates. */
  size_t max_depth_;

  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Override of \p PartitionRuleNode::AllCandidates. The candidates are those of
   * \p sub_rule_ together with the ways of combining them allowed by \p combiner_rules_.
   */
  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                const PartitionSpec& spec) const override;

  /*!
   * \brief Override of the base class hook for appending entries to \p body_items. Note the hook
   * is protected in \p PartitionRuleNode but public here.
   */
  void AppendBodyItems(std::vector<Doc>* body_items) const override;

 public:
  static constexpr const char* _type_key = "relay.collage.CombinePartitionRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(CombinePartitionRuleNode, PartitionRuleNode);
};
419
420class CombinePartitionRule : public PartitionRule {
421 public:
422 CombinePartitionRule(String rule_name, PartitionRule sub_rule, Array<CombinerRule> combiner_rules,
423 size_t max_depth_);
424
425 TVM_DEFINE_OBJECT_REF_METHODS(CombinePartitionRule, PartitionRule, CombinePartitionRuleNode);
426};
427
/*!
 * \brief Partition rule which keeps only candidates from the sub-rule whose sub-graphs are valid
 * w.r.t. the given \p SubGraphConfig.
 */
class OnlyValidPartitionRuleNode : public PartitionRuleNode {
 public:
  /*! \brief The sub-rule supplying the candidates to be filtered. */
  PartitionRule sub_rule_;
  /*! \brief The configuration against which candidate sub-graphs are checked for validity. */
  SubGraphConfig config_;

  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Override of \p PartitionRuleNode::AllCandidates. See the class comment for which
   * candidates this rule produces.
   */
  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                const PartitionSpec& spec) const override;

  /*!
   * \brief Override of the base class hook for appending entries to \p body_items. Note the hook
   * is protected in \p PartitionRuleNode but public here.
   */
  void AppendBodyItems(std::vector<Doc>* body_items) const override;

 public:
  static constexpr const char* _type_key = "relay.collage.OnlyValidPartitionRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(OnlyValidPartitionRuleNode, PartitionRuleNode);
};
448
/*!
 * \brief Reference class for \p OnlyValidPartitionRuleNode.
 */
class OnlyValidPartitionRule : public PartitionRule {
 public:
  /*!
   * \brief Constructs a rule which filters the candidates of a sub-rule.
   * \param rule_name The name for the rule.
   * \param sub_rule The sub-rule supplying the candidates.
   * \param config The configuration against which candidate sub-graphs are checked for validity.
   */
  OnlyValidPartitionRule(String rule_name, PartitionRule sub_rule, const SubGraphConfig& config);

  TVM_DEFINE_OBJECT_REF_METHODS(OnlyValidPartitionRule, PartitionRule, OnlyValidPartitionRuleNode);
};
455
/*!
 * \brief Partition rule which selects nodes which can be 'left behind' to be executed by the host
 * (eg on the VM). This includes most of the 'interstitial' Relay constructs, such as let bindings,
 * operators on references, calls to non-operator functions, and so on. It can also include the
 * construction of and projection from tuples which may not be supported within a partition.
 */
class HostPartitionRuleNode : public PartitionRuleNode {
 public:
  /*! \brief Visits this node's attributes with \p v. Defined out of line. */
  void VisitAttrs(AttrVisitor* v);

  /*!
   * \brief Override of \p PartitionRuleNode::AllCandidates. See the class comment for which
   * candidates this rule produces.
   */
  std::vector<CandidatePartition> AllCandidates(const DataflowGraph& dataflow_graph,
                                                const PartitionSpec& spec) const override;

  /*!
   * \brief Override of the base class hook for appending entries to \p body_items. Note the hook
   * is protected in \p PartitionRuleNode but public here.
   */
  void AppendBodyItems(std::vector<Doc>* body_items) const override;

 public:
  static constexpr const char* _type_key = "relay.collage.HostPartitionRule";
  TVM_DECLARE_FINAL_OBJECT_INFO(HostPartitionRuleNode, PartitionRuleNode);
};
475
/*!
 * \brief Reference class for \p HostPartitionRuleNode.
 */
class HostPartitionRule : public PartitionRule {
 public:
  /*!
   * \brief Constructs the rule.
   * \param rule_name The name for the rule.
   */
  explicit HostPartitionRule(String rule_name);

  TVM_DEFINE_OBJECT_REF_METHODS(HostPartitionRule, PartitionRule, HostPartitionRuleNode);
};
482
483} // namespace collage
484} // namespace relay
485} // namespace tvm
486
487#endif // TVM_RELAY_COLLAGE_PARTITION_RULE_H_
488