1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/collage/combiner_rule.h |
22 | * \brief Helpers for the \p CombinePartitionRule |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_COLLAGE_COMBINER_RULE_H_ |
26 | #define TVM_RELAY_COLLAGE_COMBINER_RULE_H_ |
27 | |
28 | #include <tvm/relay/dataflow_pattern.h> |
29 | #include <tvm/relay/expr.h> |
30 | |
31 | #include <string> |
32 | |
33 | #include "./candidate_partition.h" |
34 | #include "./candidate_set.h" |
35 | #include "./sub_graph.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | namespace collage { |
40 | |
41 | /*! |
42 | * \brief Base class for all 'simple' combiner rules. |
43 | * |
44 | * Given \p upstream and \p downstream candidates which touch, a simple combiner rule returns |
45 | * true if their union should also be considered a candidate. |
46 | */ |
47 | class SimpleCombinerRuleNode : public Object { |
48 | public: |
49 | String rule_name_; |
50 | |
51 | void VisitAttrs(AttrVisitor* v); |
52 | |
53 | virtual bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream, |
54 | const CandidatePartition& downstream) const; |
55 | |
56 | virtual std::string ToString() const; |
57 | |
58 | static constexpr const char* _type_key = "relay.collage.SimpleCombinerRule" ; |
59 | static constexpr const uint32_t _type_child_slots = 1; |
60 | TVM_DECLARE_BASE_OBJECT_INFO(SimpleCombinerRuleNode, Object); |
61 | }; |
62 | |
63 | class SimpleCombinerRule : public ObjectRef { |
64 | public: |
65 | explicit SimpleCombinerRule(String rule_name); |
66 | |
67 | TVM_DEFINE_OBJECT_REF_METHODS(SimpleCombinerRule, ObjectRef, SimpleCombinerRuleNode); |
68 | }; |
69 | |
70 | /*! |
71 | * \brief A simple combiner rule which fires if the \p upstream and \p downstream candidates have |
72 | * the given \p upstream_kind and \p downstream_kind (or less) respectively. |
73 | */ |
74 | class ByKindSimpleCombinerRuleNode : public SimpleCombinerRuleNode { |
75 | public: |
76 | OpPatternKind upstream_kind_; |
77 | OpPatternKind downstream_kind_; |
78 | |
79 | void VisitAttrs(AttrVisitor* v); |
80 | |
81 | bool Fires(const DataflowGraph& dataflow_graph, const CandidatePartition& upstream, |
82 | const CandidatePartition& downstream) const override; |
83 | std::string ToString() const override; |
84 | |
85 | static constexpr const char* _type_key = "relay.collage.ByKindSimpleCombinerRule" ; |
86 | TVM_DECLARE_FINAL_OBJECT_INFO(ByKindSimpleCombinerRuleNode, SimpleCombinerRuleNode); |
87 | }; |
88 | |
89 | class ByKindSimpleCombinerRule : public SimpleCombinerRule { |
90 | public: |
91 | ByKindSimpleCombinerRule(OpPatternKind upstream_kind, OpPatternKind downstream_kind); |
92 | |
93 | TVM_DEFINE_OBJECT_REF_METHODS(ByKindSimpleCombinerRule, SimpleCombinerRule, |
94 | ByKindSimpleCombinerRuleNode); |
95 | }; |
96 | |
97 | /*! \brief Context required by CombineRuleNode::AppendAllResultsContext. */ |
98 | struct AppendAllResultsContext { |
99 | AppendAllResultsContext(const DataflowGraph* dataflow_graph, size_t max_depth, |
100 | CandidateSet* candidate_set) |
101 | : dataflow_graph(dataflow_graph), max_depth(max_depth), candidate_set(candidate_set) {} |
102 | |
103 | const DataflowGraph* dataflow_graph; |
104 | size_t max_depth; |
105 | CandidateSet* candidate_set; |
106 | }; |
107 | |
108 | /*! |
109 | * \brief Base class for all 'combiner' rules. |
110 | * |
111 | * Given the current candidate set, a combiner rule looks for opportunities to form larger |
112 | * candidates, optionally removing existing candidates in the process. |
113 | */ |
114 | class CombinerRuleNode : public Object { |
115 | public: |
116 | String rule_name_; |
117 | |
118 | void VisitAttrs(AttrVisitor* v); |
119 | |
120 | virtual void AppendAllResults(AppendAllResultsContext* ctxt) const; |
121 | virtual std::string ToString() const; |
122 | |
123 | static constexpr const char* _type_key = "relay.collage.CombinerRule" ; |
124 | static constexpr const uint32_t _type_child_slots = 4; |
125 | TVM_DECLARE_BASE_OBJECT_INFO(CombinerRuleNode, Object); |
126 | }; |
127 | |
128 | class CombinerRule : public ObjectRef { |
129 | public: |
130 | explicit CombinerRule(String rule_name); |
131 | |
132 | TVM_DEFINE_OBJECT_REF_METHODS(CombinerRule, ObjectRef, CombinerRuleNode); |
133 | }; |
134 | |
135 | /*! |
136 | * \brief A combiner rule which runs one or more simple combiner rules over the current |
137 | * touching candidates. |
138 | */ |
139 | class AllSimpleCombinerRuleNode : public CombinerRuleNode { |
140 | public: |
141 | Array<SimpleCombinerRule> simple_rules_; |
142 | |
143 | void VisitAttrs(AttrVisitor* v); |
144 | |
145 | void AppendAllResults(AppendAllResultsContext* ctxt) const override; |
146 | std::string ToString() const override; |
147 | |
148 | static constexpr const char* _type_key = "relay.collage.AllSimpleCombinerRule" ; |
149 | TVM_DECLARE_FINAL_OBJECT_INFO(AllSimpleCombinerRuleNode, CombinerRuleNode); |
150 | }; |
151 | |
152 | class AllSimpleCombinerRule : public CombinerRule { |
153 | public: |
154 | AllSimpleCombinerRule(String rule_name, Array<SimpleCombinerRule> simple_rules); |
155 | |
156 | TVM_DEFINE_OBJECT_REF_METHODS(AllSimpleCombinerRule, CombinerRule, AllSimpleCombinerRuleNode); |
157 | }; |
158 | |
159 | /*! |
160 | * \brief A combiner rule which combines injective sub-groups which appear inside tuples which are |
161 | * themselves inputs to injective sub-groups. |
162 | */ |
163 | class TupleArgCombinerRuleNode : public CombinerRuleNode { |
164 | public: |
165 | void VisitAttrs(AttrVisitor* v); |
166 | |
167 | void AppendAllResults(AppendAllResultsContext* ctxt) const override; |
168 | std::string ToString() const override; |
169 | |
170 | static constexpr const char* _type_key = "relay.collage.TupleArgCombinerRule" ; |
171 | TVM_DECLARE_FINAL_OBJECT_INFO(TupleArgCombinerRuleNode, CombinerRuleNode); |
172 | }; |
173 | |
174 | class TupleArgCombinerRule : public CombinerRule { |
175 | public: |
176 | explicit TupleArgCombinerRule(String rule_name); |
177 | |
178 | TVM_DEFINE_OBJECT_REF_METHODS(TupleArgCombinerRule, CombinerRule, TupleArgCombinerRuleNode); |
179 | }; |
180 | |
181 | /*! |
182 | * \brief A combiner rule which combines tuple projection if it's an output of an injective |
183 | * group. |
184 | */ |
185 | class TupleProjCombinerRuleNode : public CombinerRuleNode { |
186 | public: |
187 | void VisitAttrs(AttrVisitor* v); |
188 | |
189 | void AppendAllResults(AppendAllResultsContext* ctxt) const override; |
190 | std::string ToString() const override; |
191 | |
192 | static constexpr const char* _type_key = "relay.collage.TupleProjCombinerRule" ; |
193 | TVM_DECLARE_FINAL_OBJECT_INFO(TupleProjCombinerRuleNode, CombinerRuleNode); |
194 | }; |
195 | |
196 | class TupleProjCombinerRule : public CombinerRule { |
197 | public: |
198 | explicit TupleProjCombinerRule(String rule_name); |
199 | |
200 | TVM_DEFINE_OBJECT_REF_METHODS(TupleProjCombinerRule, CombinerRule, TupleProjCombinerRuleNode); |
201 | }; |
202 | |
203 | /*! |
204 | * \brief A combiner rule which combines constants in argument positions to existing candidates. |
205 | * Note that scalars are always inlined, so this rule only combines tensor constant arguments. |
206 | */ |
207 | class ConstantCombinerRuleNode : public CombinerRuleNode { |
208 | public: |
209 | void VisitAttrs(AttrVisitor* v); |
210 | |
211 | void AppendAllResults(AppendAllResultsContext* ctxt) const override; |
212 | std::string ToString() const override; |
213 | |
214 | static constexpr const char* _type_key = "relay.collage.ConstantCombinerRule" ; |
215 | TVM_DECLARE_FINAL_OBJECT_INFO(ConstantCombinerRuleNode, CombinerRuleNode); |
216 | }; |
217 | |
218 | class ConstantCombinerRule : public CombinerRule { |
219 | public: |
220 | explicit ConstantCombinerRule(String rule_name); |
221 | |
222 | TVM_DEFINE_OBJECT_REF_METHODS(ConstantCombinerRule, CombinerRule, ConstantCombinerRuleNode); |
223 | }; |
224 | |
225 | } // namespace collage |
226 | } // namespace relay |
227 | } // namespace tvm |
228 | |
229 | #endif // TVM_RELAY_COLLAGE_COMBINER_RULE_H_ |
230 | |