1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/adt.h |
22 | * \brief Algebraic data types for Relay |
23 | */ |
24 | #ifndef TVM_RELAY_ADT_H_ |
25 | #define TVM_RELAY_ADT_H_ |
26 | |
27 | #include <tvm/ir/adt.h> |
28 | #include <tvm/ir/attrs.h> |
29 | #include <tvm/relay/base.h> |
30 | #include <tvm/relay/expr.h> |
31 | #include <tvm/relay/type.h> |
32 | |
33 | #include <functional> |
34 | #include <string> |
35 | #include <utility> |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
40 | using Constructor = tvm::Constructor; |
41 | using ConstructorNode = tvm::ConstructorNode; |
42 | |
43 | using TypeData = tvm::TypeData; |
44 | using TypeDataNode = tvm::TypeDataNode; |
45 | |
46 | /*! \brief Base type for declaring relay pattern. */ |
47 | class PatternNode : public RelayNode { |
48 | public: |
49 | static constexpr const char* _type_key = "relay.Pattern" ; |
50 | static constexpr const bool _type_has_method_sequal_reduce = true; |
51 | static constexpr const bool _type_has_method_shash_reduce = true; |
52 | TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object); |
53 | }; |
54 | |
55 | /*! |
56 | * \brief Pattern is the base type for an ADT match pattern in Relay. |
57 | * |
58 | * Given an ADT value, a pattern might accept it and bind the pattern variable to some value |
59 | * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value. |
60 | * |
61 | * ADT pattern matching thus takes a list of values and binds to the first that accepts the value. |
62 | */ |
63 | class Pattern : public ObjectRef { |
64 | public: |
65 | Pattern() {} |
66 | explicit Pattern(ObjectPtr<tvm::Object> p) : ObjectRef(p) {} |
67 | |
68 | using ContainerType = PatternNode; |
69 | }; |
70 | |
71 | /*! \brief A wildcard pattern: Accepts all input and binds nothing. */ |
72 | class PatternWildcard; |
73 | /*! \brief PatternWildcard container node */ |
74 | class PatternWildcardNode : public PatternNode { |
75 | public: |
76 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span" , &span); } |
77 | |
78 | bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; } |
79 | |
80 | void SHashReduce(SHashReducer hash_reduce) const {} |
81 | |
82 | static constexpr const char* _type_key = "relay.PatternWildcard" ; |
83 | TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); |
84 | }; |
85 | |
86 | class PatternWildcard : public Pattern { |
87 | public: |
88 | /* \brief Overload the default constructors. */ |
89 | TVM_DLL PatternWildcard(); |
90 | explicit PatternWildcard(ObjectPtr<Object> n) : Pattern(n) {} |
91 | /* \brief Copy constructor. */ |
92 | PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {} |
93 | /* \brief Move constructor. */ |
94 | PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {} |
95 | /* \brief Copy assignment. */ |
96 | PatternWildcard& operator=(const PatternWildcard& other) { |
97 | (*this).data_ = other.data_; |
98 | return *this; |
99 | } |
100 | /* \brief Move assignment. */ |
101 | PatternWildcard& operator=(PatternWildcard&& other) { |
102 | (*this).data_ = std::move(other.data_); |
103 | return *this; |
104 | } |
105 | |
106 | const PatternWildcardNode* operator->() const { |
107 | return static_cast<const PatternWildcardNode*>(get()); |
108 | } |
109 | |
110 | using ContainerType = PatternWildcardNode; |
111 | }; |
112 | |
113 | /*! \brief A var pattern. Accept all input and bind to a var. */ |
114 | class PatternVar; |
115 | /*! \brief PatternVar container node */ |
116 | class PatternVarNode : public PatternNode { |
117 | public: |
118 | /*! \brief Variable that stores the matched value. */ |
119 | tvm::relay::Var var; |
120 | |
121 | void VisitAttrs(tvm::AttrVisitor* v) { |
122 | v->Visit("var" , &var); |
123 | v->Visit("span" , &span); |
124 | } |
125 | |
126 | bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const { |
127 | return equal.DefEqual(var, other->var); |
128 | } |
129 | |
130 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); } |
131 | |
132 | static constexpr const char* _type_key = "relay.PatternVar" ; |
133 | TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); |
134 | }; |
135 | |
136 | class PatternVar : public Pattern { |
137 | public: |
138 | /*! |
139 | * \brief Constructor |
140 | * \param var The var to construct a pattern |
141 | */ |
142 | TVM_DLL explicit PatternVar(tvm::relay::Var var); |
143 | |
144 | TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode); |
145 | }; |
146 | |
147 | /*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */ |
148 | class PatternConstructor; |
149 | /*! \brief PatternVar container node */ |
150 | class PatternConstructorNode : public PatternNode { |
151 | public: |
152 | /*! Constructor matched by the pattern. */ |
153 | Constructor constructor; |
154 | /*! Sub-patterns to match against each input to the constructor. */ |
155 | tvm::Array<Pattern> patterns; |
156 | |
157 | void VisitAttrs(tvm::AttrVisitor* v) { |
158 | v->Visit("constructor" , &constructor); |
159 | v->Visit("patterns" , &patterns); |
160 | v->Visit("span" , &span); |
161 | } |
162 | |
163 | bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const { |
164 | return equal(constructor, other->constructor) && equal(patterns, other->patterns); |
165 | } |
166 | |
167 | void SHashReduce(SHashReducer hash_reduce) const { |
168 | hash_reduce(constructor); |
169 | hash_reduce(patterns); |
170 | } |
171 | |
172 | static constexpr const char* _type_key = "relay.PatternConstructor" ; |
173 | TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode); |
174 | }; |
175 | |
176 | class PatternConstructor : public Pattern { |
177 | public: |
178 | /*! |
179 | * \brief Constructor |
180 | * \param constructor The constructor of a pattern |
181 | * \param patterns The sub-patterns for matching |
182 | */ |
183 | TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns); |
184 | |
185 | TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode); |
186 | }; |
187 | |
188 | /*! \brief A tuple pattern. Matches a tuple, binds recursively. */ |
189 | class PatternTuple; |
190 | /*! \brief PatternVar container node */ |
191 | class PatternTupleNode : public PatternNode { |
192 | public: |
193 | /* TODO(@jroesch): rename to field_pats */ |
194 | /*! Sub-patterns to match against each value of the tuple. */ |
195 | tvm::Array<Pattern> patterns; |
196 | |
197 | void VisitAttrs(tvm::AttrVisitor* v) { |
198 | v->Visit("patterns" , &patterns); |
199 | v->Visit("span" , &span); |
200 | } |
201 | |
202 | bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const { |
203 | return equal(patterns, other->patterns); |
204 | } |
205 | |
206 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); } |
207 | |
208 | static constexpr const char* _type_key = "relay.PatternTuple" ; |
209 | TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); |
210 | }; |
211 | |
212 | class PatternTuple : public Pattern { |
213 | public: |
214 | /*! |
215 | * \brief Constructor |
216 | * \param patterns The sub-patterns to match against each value of the tuple |
217 | */ |
218 | TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns); |
219 | |
220 | TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode); |
221 | }; |
222 | |
223 | /*! \brief A clause in a match expression. */ |
224 | class Clause; |
225 | /*! \brief Clause container node. */ |
226 | class ClauseNode : public Object { |
227 | public: |
228 | /*! \brief The pattern the clause matches. */ |
229 | Pattern lhs; |
230 | /*! \brief The resulting value. */ |
231 | Expr rhs; |
232 | |
233 | void VisitAttrs(tvm::AttrVisitor* v) { |
234 | v->Visit("lhs" , &lhs); |
235 | v->Visit("rhs" , &rhs); |
236 | } |
237 | |
238 | bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const { |
239 | return equal(lhs, other->lhs) && equal(rhs, other->rhs); |
240 | } |
241 | |
242 | void SHashReduce(SHashReducer hash_reduce) const { |
243 | hash_reduce(lhs); |
244 | hash_reduce(rhs); |
245 | } |
246 | |
247 | static constexpr const char* _type_key = "relay.Clause" ; |
248 | static constexpr const bool _type_has_method_sequal_reduce = true; |
249 | static constexpr const bool _type_has_method_shash_reduce = true; |
250 | TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object); |
251 | }; |
252 | |
253 | class Clause : public ObjectRef { |
254 | public: |
255 | /*! |
256 | * \brief Constructor |
257 | * \param lhs The pattern matched by the clause. |
258 | * \param rhs The resulting value |
259 | */ |
260 | TVM_DLL explicit Clause(Pattern lhs, Expr rhs); |
261 | |
262 | TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode); |
263 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ClauseNode); |
264 | }; |
265 | |
266 | /*! |
267 | * \brief Returns \p clause with the given properties. A null property denotes 'no change'. |
268 | * Returns \p clause if all properties are unchanged. Otherwise, returns a copy with the new |
269 | * fields. |
270 | */ |
271 | Clause WithFields(Clause clause, Optional<Pattern> opt_lhs = Optional<Pattern>(), |
272 | Optional<Expr> opt_rhs = Optional<Expr>()); |
273 | |
274 | /*! \brief ADT pattern matching exression. */ |
275 | class Match; |
276 | /*! \brief Match container node. */ |
277 | class MatchNode : public ExprNode { |
278 | public: |
279 | /*! \brief The input being deconstructed. */ |
280 | Expr data; |
281 | |
282 | /*! \brief The match node clauses. */ |
283 | tvm::Array<Clause> clauses; |
284 | |
285 | /*! \brief Should this match be complete (cover all cases)? |
286 | * If yes, the type checker will generate an error if there are any missing cases. |
287 | */ |
288 | bool complete; |
289 | |
290 | void VisitAttrs(tvm::AttrVisitor* v) { |
291 | v->Visit("data" , &data); |
292 | v->Visit("clauses" , &clauses); |
293 | v->Visit("complete" , &complete); |
294 | v->Visit("virtual_device_" , &virtual_device_); |
295 | v->Visit("span" , &span); |
296 | v->Visit("_checked_type_" , &checked_type_); |
297 | } |
298 | |
299 | bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const { |
300 | equal->MarkGraphNode(); |
301 | return equal(data, other->data) && equal(clauses, other->clauses) && |
302 | equal(complete, other->complete); |
303 | } |
304 | |
305 | void SHashReduce(SHashReducer hash_reduce) const { |
306 | hash_reduce->MarkGraphNode(); |
307 | hash_reduce(data); |
308 | hash_reduce(clauses); |
309 | hash_reduce(complete); |
310 | } |
311 | |
312 | static constexpr const char* _type_key = "relay.Match" ; |
313 | TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode); |
314 | }; |
315 | |
316 | class Match : public Expr { |
317 | public: |
318 | /*! |
319 | * \brief Constructor |
320 | * \param data the input being deconstructed. |
321 | * \param clauses The clauses for matching. |
322 | * \param complete Indicate if this match is complete. |
323 | * \param span The span of the expression. |
324 | */ |
325 | TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span()); |
326 | |
327 | TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode); |
328 | TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchNode); |
329 | }; |
330 | |
331 | /*! |
332 | * \brief Returns \p match with the given properties. A null property denotes 'no change'. |
333 | * Returns \p match if all properties are unchanged. Otherwise, returns a copy with the new |
334 | * fields. |
335 | */ |
336 | Match WithFields(Match match, Optional<Expr> opt_data = Optional<Expr>(), |
337 | Optional<Array<Clause>> opt_clauses = Optional<Array<Clause>>(), |
338 | Optional<Bool> opt_complete = Optional<Bool>(), |
339 | Optional<Span> opt_span = Optional<Span>()); |
340 | |
341 | } // namespace relay |
342 | } // namespace tvm |
343 | |
344 | #endif // TVM_RELAY_ADT_H_ |
345 | |