1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/adt.h
22 * \brief Algebraic data types for Relay
23 */
24#ifndef TVM_RELAY_ADT_H_
25#define TVM_RELAY_ADT_H_
26
27#include <tvm/ir/adt.h>
28#include <tvm/ir/attrs.h>
29#include <tvm/relay/base.h>
30#include <tvm/relay/expr.h>
31#include <tvm/relay/type.h>
32
33#include <functional>
34#include <string>
35#include <utility>
36
37namespace tvm {
38namespace relay {
39
40using Constructor = tvm::Constructor;
41using ConstructorNode = tvm::ConstructorNode;
42
43using TypeData = tvm::TypeData;
44using TypeDataNode = tvm::TypeDataNode;
45
46/*! \brief Base type for declaring relay pattern. */
47class PatternNode : public RelayNode {
48 public:
49 static constexpr const char* _type_key = "relay.Pattern";
50 static constexpr const bool _type_has_method_sequal_reduce = true;
51 static constexpr const bool _type_has_method_shash_reduce = true;
52 TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
53};
54
55/*!
56 * \brief Pattern is the base type for an ADT match pattern in Relay.
57 *
58 * Given an ADT value, a pattern might accept it and bind the pattern variable to some value
59 * (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
60 *
61 * ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
62 */
63class Pattern : public ObjectRef {
64 public:
65 Pattern() {}
66 explicit Pattern(ObjectPtr<tvm::Object> p) : ObjectRef(p) {}
67
68 using ContainerType = PatternNode;
69};
70
71/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
72class PatternWildcard;
73/*! \brief PatternWildcard container node */
74class PatternWildcardNode : public PatternNode {
75 public:
76 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); }
77
78 bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; }
79
80 void SHashReduce(SHashReducer hash_reduce) const {}
81
82 static constexpr const char* _type_key = "relay.PatternWildcard";
83 TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
84};
85
86class PatternWildcard : public Pattern {
87 public:
88 /* \brief Overload the default constructors. */
89 TVM_DLL PatternWildcard();
90 explicit PatternWildcard(ObjectPtr<Object> n) : Pattern(n) {}
91 /* \brief Copy constructor. */
92 PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {}
93 /* \brief Move constructor. */
94 PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {}
95 /* \brief Copy assignment. */
96 PatternWildcard& operator=(const PatternWildcard& other) {
97 (*this).data_ = other.data_;
98 return *this;
99 }
100 /* \brief Move assignment. */
101 PatternWildcard& operator=(PatternWildcard&& other) {
102 (*this).data_ = std::move(other.data_);
103 return *this;
104 }
105
106 const PatternWildcardNode* operator->() const {
107 return static_cast<const PatternWildcardNode*>(get());
108 }
109
110 using ContainerType = PatternWildcardNode;
111};
112
113/*! \brief A var pattern. Accept all input and bind to a var. */
114class PatternVar;
115/*! \brief PatternVar container node */
116class PatternVarNode : public PatternNode {
117 public:
118 /*! \brief Variable that stores the matched value. */
119 tvm::relay::Var var;
120
121 void VisitAttrs(tvm::AttrVisitor* v) {
122 v->Visit("var", &var);
123 v->Visit("span", &span);
124 }
125
126 bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
127 return equal.DefEqual(var, other->var);
128 }
129
130 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); }
131
132 static constexpr const char* _type_key = "relay.PatternVar";
133 TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
134};
135
136class PatternVar : public Pattern {
137 public:
138 /*!
139 * \brief Constructor
140 * \param var The var to construct a pattern
141 */
142 TVM_DLL explicit PatternVar(tvm::relay::Var var);
143
144 TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
145};
146
147/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
148class PatternConstructor;
149/*! \brief PatternVar container node */
150class PatternConstructorNode : public PatternNode {
151 public:
152 /*! Constructor matched by the pattern. */
153 Constructor constructor;
154 /*! Sub-patterns to match against each input to the constructor. */
155 tvm::Array<Pattern> patterns;
156
157 void VisitAttrs(tvm::AttrVisitor* v) {
158 v->Visit("constructor", &constructor);
159 v->Visit("patterns", &patterns);
160 v->Visit("span", &span);
161 }
162
163 bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
164 return equal(constructor, other->constructor) && equal(patterns, other->patterns);
165 }
166
167 void SHashReduce(SHashReducer hash_reduce) const {
168 hash_reduce(constructor);
169 hash_reduce(patterns);
170 }
171
172 static constexpr const char* _type_key = "relay.PatternConstructor";
173 TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
174};
175
176class PatternConstructor : public Pattern {
177 public:
178 /*!
179 * \brief Constructor
180 * \param constructor The constructor of a pattern
181 * \param patterns The sub-patterns for matching
182 */
183 TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns);
184
185 TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode);
186};
187
188/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
189class PatternTuple;
190/*! \brief PatternVar container node */
191class PatternTupleNode : public PatternNode {
192 public:
193 /* TODO(@jroesch): rename to field_pats */
194 /*! Sub-patterns to match against each value of the tuple. */
195 tvm::Array<Pattern> patterns;
196
197 void VisitAttrs(tvm::AttrVisitor* v) {
198 v->Visit("patterns", &patterns);
199 v->Visit("span", &span);
200 }
201
202 bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
203 return equal(patterns, other->patterns);
204 }
205
206 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); }
207
208 static constexpr const char* _type_key = "relay.PatternTuple";
209 TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
210};
211
212class PatternTuple : public Pattern {
213 public:
214 /*!
215 * \brief Constructor
216 * \param patterns The sub-patterns to match against each value of the tuple
217 */
218 TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns);
219
220 TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
221};
222
223/*! \brief A clause in a match expression. */
224class Clause;
225/*! \brief Clause container node. */
226class ClauseNode : public Object {
227 public:
228 /*! \brief The pattern the clause matches. */
229 Pattern lhs;
230 /*! \brief The resulting value. */
231 Expr rhs;
232
233 void VisitAttrs(tvm::AttrVisitor* v) {
234 v->Visit("lhs", &lhs);
235 v->Visit("rhs", &rhs);
236 }
237
238 bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
239 return equal(lhs, other->lhs) && equal(rhs, other->rhs);
240 }
241
242 void SHashReduce(SHashReducer hash_reduce) const {
243 hash_reduce(lhs);
244 hash_reduce(rhs);
245 }
246
247 static constexpr const char* _type_key = "relay.Clause";
248 static constexpr const bool _type_has_method_sequal_reduce = true;
249 static constexpr const bool _type_has_method_shash_reduce = true;
250 TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
251};
252
253class Clause : public ObjectRef {
254 public:
255 /*!
256 * \brief Constructor
257 * \param lhs The pattern matched by the clause.
258 * \param rhs The resulting value
259 */
260 TVM_DLL explicit Clause(Pattern lhs, Expr rhs);
261
262 TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
263 TVM_DEFINE_OBJECT_REF_COW_METHOD(ClauseNode);
264};
265
/*!
 * \brief Returns \p clause with the given properties. A null property denotes 'no change'.
 * Returns \p clause if all properties are unchanged. Otherwise, returns a copy with the new
 * fields.
 * \param clause The clause to update.
 * \param opt_lhs The new pattern, or null to keep the clause's current pattern.
 * \param opt_rhs The new result expression, or null to keep the clause's current one.
 * \return \p clause itself when nothing changes, otherwise an updated copy.
 */
Clause WithFields(Clause clause, Optional<Pattern> opt_lhs = Optional<Pattern>(),
                  Optional<Expr> opt_rhs = Optional<Expr>());
273
274/*! \brief ADT pattern matching exression. */
275class Match;
276/*! \brief Match container node. */
277class MatchNode : public ExprNode {
278 public:
279 /*! \brief The input being deconstructed. */
280 Expr data;
281
282 /*! \brief The match node clauses. */
283 tvm::Array<Clause> clauses;
284
285 /*! \brief Should this match be complete (cover all cases)?
286 * If yes, the type checker will generate an error if there are any missing cases.
287 */
288 bool complete;
289
290 void VisitAttrs(tvm::AttrVisitor* v) {
291 v->Visit("data", &data);
292 v->Visit("clauses", &clauses);
293 v->Visit("complete", &complete);
294 v->Visit("virtual_device_", &virtual_device_);
295 v->Visit("span", &span);
296 v->Visit("_checked_type_", &checked_type_);
297 }
298
299 bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
300 equal->MarkGraphNode();
301 return equal(data, other->data) && equal(clauses, other->clauses) &&
302 equal(complete, other->complete);
303 }
304
305 void SHashReduce(SHashReducer hash_reduce) const {
306 hash_reduce->MarkGraphNode();
307 hash_reduce(data);
308 hash_reduce(clauses);
309 hash_reduce(complete);
310 }
311
312 static constexpr const char* _type_key = "relay.Match";
313 TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
314};
315
316class Match : public Expr {
317 public:
318 /*!
319 * \brief Constructor
320 * \param data the input being deconstructed.
321 * \param clauses The clauses for matching.
322 * \param complete Indicate if this match is complete.
323 * \param span The span of the expression.
324 */
325 TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true, Span span = Span());
326
327 TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
328 TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchNode);
329};
330
/*!
 * \brief Returns \p match with the given properties. A null property denotes 'no change'.
 * Returns \p match if all properties are unchanged. Otherwise, returns a copy with the new
 * fields.
 * \param match The match expression to update.
 * \param opt_data The new input expression, or null to keep the current one.
 * \param opt_clauses The new clauses, or null to keep the current ones.
 * \param opt_complete The new completeness flag, or null to keep the current one.
 * \param opt_span The new span, or null to keep the current one.
 * \return \p match itself when nothing changes, otherwise an updated copy.
 */
Match WithFields(Match match, Optional<Expr> opt_data = Optional<Expr>(),
                 Optional<Array<Clause>> opt_clauses = Optional<Array<Clause>>(),
                 Optional<Bool> opt_complete = Optional<Bool>(),
                 Optional<Span> opt_span = Optional<Span>());
340
341} // namespace relay
342} // namespace tvm
343
344#endif // TVM_RELAY_ADT_H_
345