/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
/*!
 * \file src/relay/collage/combiner_rule.h
 * \brief Helpers for the \p CombinePartitionRule
 */
24
25#ifndef TVM_RELAY_COLLAGE_COMBINER_RULE_H_
26#define TVM_RELAY_COLLAGE_COMBINER_RULE_H_
27
28#include <tvm/relay/dataflow_pattern.h>
29#include <tvm/relay/expr.h>
30
31#include <string>
32
33#include "./candidate_partition.h"
34#include "./candidate_set.h"
35#include "./sub_graph.h"
36
37namespace tvm {
38namespace relay {
39namespace collage {
40
41/*!
42 * \brief Base class for all 'simple' combiner rules.
43 *
44 * Given \p upstream and \p downstream candidates which touch, a simple combiner rule returns
45 * true if their union should also be considered a candidate.
46 */
47class SimpleCombinerRuleNode : public Object {
48 public:
49 String rule_name_;
50
51 void VisitAttrs(AttrVisitor* v);
52
53 virtual bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream,
54 const CandidatePartition& downstream) const;
55
56 virtual std::string ToString() const;
57
58 static constexpr const char* _type_key = "relay.collage.SimpleCombinerRule";
59 static constexpr const uint32_t _type_child_slots = 1;
60 TVM_DECLARE_BASE_OBJECT_INFO(SimpleCombinerRuleNode, Object);
61};
62
63class SimpleCombinerRule : public ObjectRef {
64 public:
65 explicit SimpleCombinerRule(String rule_name);
66
67 TVM_DEFINE_OBJECT_REF_METHODS(SimpleCombinerRule, ObjectRef, SimpleCombinerRuleNode);
68};
69
70/*!
71 * \brief A simple combiner rule which fires if the \p upstream and \p downstream candidates have
72 * the given \p upstream_kind and \p downstream_kind (or less) respectively.
73 */
74class ByKindSimpleCombinerRuleNode : public SimpleCombinerRuleNode {
75 public:
76 OpPatternKind upstream_kind_;
77 OpPatternKind downstream_kind_;
78
79 void VisitAttrs(AttrVisitor* v);
80
81 bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream,
82 const CandidatePartition& downstream) const override;
83 std::string ToString() const override;
84
85 static constexpr const char* _type_key = "relay.collage.ByKindSimpleCombinerRule";
86 TVM_DECLARE_FINAL_OBJECT_INFO(ByKindSimpleCombinerRuleNode, SimpleCombinerRuleNode);
87};
88
89class ByKindSimpleCombinerRule : public SimpleCombinerRule {
90 public:
91 ByKindSimpleCombinerRule(OpPatternKind upstream_kind, OpPatternKind downstream_kind);
92
93 TVM_DEFINE_OBJECT_REF_METHODS(ByKindSimpleCombinerRule, SimpleCombinerRule,
94 ByKindSimpleCombinerRuleNode);
95};
96
97/*! \brief Context required by CombineRuleNode::AppendAllResultsContext. */
98struct AppendAllResultsContext {
99 AppendAllResultsContext(const DataflowGraph* dataflow_graph, size_t max_depth,
100 CandidateSet* candidate_set)
101 : dataflow_graph(dataflow_graph), max_depth(max_depth), candidate_set(candidate_set) {}
102
103 const DataflowGraph* dataflow_graph;
104 size_t max_depth;
105 CandidateSet* candidate_set;
106};
107
108/*!
109 * \brief Base class for all 'combiner' rules.
110 *
111 * Given the current candidate set, a combiner rule looks for opportunities to form larger
112 * candidates, optionally removing existing candidates in the process.
113 */
114class CombinerRuleNode : public Object {
115 public:
116 String rule_name_;
117
118 void VisitAttrs(AttrVisitor* v);
119
120 virtual void AppendAllResults(AppendAllResultsContext* ctxt) const;
121 virtual std::string ToString() const;
122
123 static constexpr const char* _type_key = "relay.collage.CombinerRule";
124 static constexpr const uint32_t _type_child_slots = 4;
125 TVM_DECLARE_BASE_OBJECT_INFO(CombinerRuleNode, Object);
126};
127
128class CombinerRule : public ObjectRef {
129 public:
130 explicit CombinerRule(String rule_name);
131
132 TVM_DEFINE_OBJECT_REF_METHODS(CombinerRule, ObjectRef, CombinerRuleNode);
133};
134
135/*!
136 * \brief A combiner rule which runs one or more simple combiner rules over the current
137 * touching candidates.
138 */
139class AllSimpleCombinerRuleNode : public CombinerRuleNode {
140 public:
141 Array<SimpleCombinerRule> simple_rules_;
142
143 void VisitAttrs(AttrVisitor* v);
144
145 void AppendAllResults(AppendAllResultsContext* ctxt) const override;
146 std::string ToString() const override;
147
148 static constexpr const char* _type_key = "relay.collage.AllSimpleCombinerRule";
149 TVM_DECLARE_FINAL_OBJECT_INFO(AllSimpleCombinerRuleNode, CombinerRuleNode);
150};
151
152class AllSimpleCombinerRule : public CombinerRule {
153 public:
154 AllSimpleCombinerRule(String rule_name, Array<SimpleCombinerRule> simple_rules);
155
156 TVM_DEFINE_OBJECT_REF_METHODS(AllSimpleCombinerRule, CombinerRule, AllSimpleCombinerRuleNode);
157};
158
159/*!
160 * \brief A combiner rule which combines injective sub-groups which appear inside tuples which are
161 * themselves inputs to injective sub-groups.
162 */
163class TupleArgCombinerRuleNode : public CombinerRuleNode {
164 public:
165 void VisitAttrs(AttrVisitor* v);
166
167 void AppendAllResults(AppendAllResultsContext* ctxt) const override;
168 std::string ToString() const override;
169
170 static constexpr const char* _type_key = "relay.collage.TupleArgCombinerRule";
171 TVM_DECLARE_FINAL_OBJECT_INFO(TupleArgCombinerRuleNode, CombinerRuleNode);
172};
173
174class TupleArgCombinerRule : public CombinerRule {
175 public:
176 explicit TupleArgCombinerRule(String rule_name);
177
178 TVM_DEFINE_OBJECT_REF_METHODS(TupleArgCombinerRule, CombinerRule, TupleArgCombinerRuleNode);
179};
180
181/*!
182 * \brief A combiner rule which combines tuple projection if it's an output of an injective
183 * group.
184 */
185class TupleProjCombinerRuleNode : public CombinerRuleNode {
186 public:
187 void VisitAttrs(AttrVisitor* v);
188
189 void AppendAllResults(AppendAllResultsContext* ctxt) const override;
190 std::string ToString() const override;
191
192 static constexpr const char* _type_key = "relay.collage.TupleProjCombinerRule";
193 TVM_DECLARE_FINAL_OBJECT_INFO(TupleProjCombinerRuleNode, CombinerRuleNode);
194};
195
196class TupleProjCombinerRule : public CombinerRule {
197 public:
198 explicit TupleProjCombinerRule(String rule_name);
199
200 TVM_DEFINE_OBJECT_REF_METHODS(TupleProjCombinerRule, CombinerRule, TupleProjCombinerRuleNode);
201};
202
203/*!
204 * \brief A combiner rule which combines constants in argument positions to existing candidates.
205 * Note that scalars are always inlined, so this rule only combines tensor constant arguments.
206 */
207class ConstantCombinerRuleNode : public CombinerRuleNode {
208 public:
209 void VisitAttrs(AttrVisitor* v);
210
211 void AppendAllResults(AppendAllResultsContext* ctxt) const override;
212 std::string ToString() const override;
213
214 static constexpr const char* _type_key = "relay.collage.ConstantCombinerRule";
215 TVM_DECLARE_FINAL_OBJECT_INFO(ConstantCombinerRuleNode, CombinerRuleNode);
216};
217
218class ConstantCombinerRule : public CombinerRule {
219 public:
220 explicit ConstantCombinerRule(String rule_name);
221
222 TVM_DEFINE_OBJECT_REF_METHODS(ConstantCombinerRule, CombinerRule, ConstantCombinerRuleNode);
223};
224
225} // namespace collage
226} // namespace relay
227} // namespace tvm
228
229#endif // TVM_RELAY_COLLAGE_COMBINER_RULE_H_
230