1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/dataflow_pattern.h
22 * \brief A pattern language for matching dataflow properties.
23 */
24#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_
25#define TVM_RELAY_DATAFLOW_PATTERN_H_
26
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>

#include <functional>
#include <string>
#include <vector>
32
33namespace tvm {
34namespace relay {
35
36/*!
37 * \brief Base type of all dataflow patterns.
38 * \sa DFPattern
39 */
40class DFPatternNode : public Object {
41 public:
42 static constexpr const char* _type_key = "DFPatternNode";
43 TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
44};
45
46/*!
47 * \brief Managed reference to dataflow patterns.
48 * \sa DFPatternNode
49 */
50class DFPattern : public ObjectRef {
51 public:
52 /*! \brief Syntatic Sugar for creating a CallPattern */
53 DFPattern operator()(const std::vector<DFPattern>& args) const;
54 /*! \brief Syntatic Sugar for creating a CallPattern with an "add" op */
55 DFPattern operator+(const DFPattern& other) const;
56 /*! \brief Syntatic Sugar for creating a CallPattern with a "subtract" op */
57 DFPattern operator-(const DFPattern& other) const;
58 /*! \brief Syntatic Sugar for creating a CallPattern with a "multiply" op */
59 DFPattern operator*(const DFPattern& other) const;
60 /*! \brief Syntatic Sugar for creating a CallPattern with a "divide" op */
61 DFPattern operator/(const DFPattern& other) const;
62 /*! \brief Syntatic Sugar for creating an AltPattern */
63 DFPattern operator||(const DFPattern& other) const;
64 /*! \brief Syntatic Sugar for creating an Optional Pattern */
65 DFPattern Optional(const std::function<DFPattern(const DFPattern&)>& func) const;
66 /*! \brief Syntatic Sugar for creating an AttrPattern */
67 DFPattern HasAttr(const Map<String, ObjectRef>& attrs) const;
68 /*! \brief Syntatic Sugar for creating a TypePattern */
69 DFPattern HasType(const Type& type) const;
70 /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */
71 DFPattern HasDtype(const DataType& dtype) const;
72 /*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */
73 DFPattern HasDtype(const std::string& dtype) const;
74 /*! \brief Syntatic Sugar for creating a ShapePattern */
75 DFPattern HasShape(const Array<PrimExpr> shape) const;
76
77 TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
78};
79
80/*!
81 * \brief Pattern for Relay Expression.
82 */
83class ExprPatternNode : public DFPatternNode {
84 public:
85 /*! \brief The expression to match. */
86 Expr expr;
87
88 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); }
89
90 static constexpr const char* _type_key = "relay.dataflow_pattern.ExprPattern";
91 TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode);
92};
93
94/*!
95 * \brief A pattern which matches a literal expression.
96 *
97 * \note Uses structural equality on expressions to check equality.
98 *
99 */
100class ExprPattern : public DFPattern {
101 public:
102 TVM_DLL explicit ExprPattern(Expr expr);
103 TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode);
104};
105
106/*!
107 * \brief A Pattern to Match a Relay Variable
108 */
109class VarPattern;
110/*! \brief Container for Var */
111class VarPatternNode : public DFPatternNode {
112 public:
113 /*!
114 * \brief The name of the Var (optional).
115 */
116 String name;
117
118 /*! \return The name hint of the variable */
119 const String& name_hint() const { return name; }
120
121 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); }
122
123 static constexpr const char* _type_key = "relay.dataflow_pattern.VarPattern";
124 TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode);
125};
126
127class VarPattern : public DFPattern {
128 public:
129 TVM_DLL VarPattern(String name_hint);
130 TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
131};
132
133/*!
134 * \brief A Pattern to Match a Relay Constant
135 */
136class ConstantPattern;
137/*! \brief Container for Constant */
138class ConstantPatternNode : public DFPatternNode {
139 public:
140 void VisitAttrs(tvm::AttrVisitor* v) {}
141
142 static constexpr const char* _type_key = "relay.dataflow_pattern.ConstantPattern";
143 TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode);
144};
145
146class ConstantPattern : public DFPattern {
147 public:
148 TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode);
149};
150
151/*!
152 * \brief Call corresponds to operator invocation.
153 * Corresponds to the operator in computational graph terminology.
154 */
155class CallPattern;
156/*! \brief CallPattern container. */
157class CallPatternNode : public DFPatternNode {
158 public:
159 /*!
160 * \brief The operator(function) being invoked
161 *
162 * - It can be relay::Op which corresponds to the primitive operators.
163 * - It can also be user defined functions (Function, GlobalVar, Var).
164 */
165 DFPattern op;
166
167 /*! \brief The arguments(inputs) of the call */
168 tvm::Array<relay::DFPattern> args;
169
170 void VisitAttrs(tvm::AttrVisitor* v) {
171 v->Visit("op", &op);
172 v->Visit("args", &args);
173 }
174
175 static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern";
176 TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode);
177};
178
179class CallPattern : public DFPattern {
180 public:
181 TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args);
182 TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode);
183};
184
185/*!
186 * \brief Relay Function container
187 * \sa Function
188 */
189class FunctionPatternNode : public DFPatternNode {
190 public:
191 /*! \brief Function parameters */
192 tvm::Array<DFPattern> params;
193 /*!
194 * \brief
195 * The expression which represents the computation of the function,
196 * the expression may reference the parameters, and the type of it
197 * or sub-expressions may reference the type variables.
198 */
199 DFPattern body;
200
201 void VisitAttrs(tvm::AttrVisitor* v) {
202 v->Visit("params", &params);
203 v->Visit("body", &body);
204 }
205
206 static constexpr const char* _type_key = "relay.dataflow_pattern.FunctionPattern";
207 TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode);
208};
209
210/*!
211 * \brief Managed reference to FunctionNode.
212 * \sa FunctionNode
213 */
214class FunctionPattern : public DFPattern {
215 public:
216 /*!
217 * \brief Constructor
218 * \param params The parameters of the function.
219 * \param body The body of the function.
220 */
221 TVM_DLL FunctionPattern(tvm::Array<DFPattern> params, DFPattern body);
222
223 TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode);
224 TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
225};
226
227/*! \brief A binding of a sub-network. */
228class LetPatternNode : public DFPatternNode {
229 public:
230 /*! \brief The variable we bind to */
231 DFPattern var;
232 /*! \brief The value we bind var to */
233 DFPattern value;
234 /*! \brief The body of the let binding */
235 DFPattern body;
236
237 void VisitAttrs(tvm::AttrVisitor* v) {
238 v->Visit("var", &var);
239 v->Visit("value", &value);
240 v->Visit("body", &body);
241 }
242
243 static constexpr const char* _type_key = "relay.dataflow_pattern.LetPattern";
244 TVM_DECLARE_FINAL_OBJECT_INFO(LetPatternNode, DFPatternNode);
245};
246
247/*!
248 * \brief Let binding that binds a local var
249 */
250class LetPattern : public DFPattern {
251 public:
252 /*!
253 * \brief The constructor
254 * \param var The variable that is bound to.
255 * \param value The value used to bind to the variable.
256 * \param body The body of the let binding.
257 */
258 TVM_DLL LetPattern(DFPattern var, DFPattern value, DFPattern body);
259
260 TVM_DEFINE_OBJECT_REF_METHODS(LetPattern, DFPattern, LetPatternNode);
261};
262
263/*! \brief Tuple of multiple Exprs */
264class TuplePattern;
265/*! \brief Tuple container */
266class TuplePatternNode : public DFPatternNode {
267 public:
268 /*! \brief the fields of the tuple */
269 tvm::Array<DFPattern> fields;
270
271 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); }
272
273 static constexpr const char* _type_key = "relay.dataflow_pattern.TuplePattern";
274 TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode);
275};
276
277class TuplePattern : public DFPattern {
278 public:
279 TVM_DLL explicit TuplePattern(tvm::Array<DFPattern> fields);
280 TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode);
281};
282
283/*! \brief Get index-th field out of a tuple. */
284class TupleGetItemPattern;
285class TupleGetItemPatternNode : public DFPatternNode {
286 public:
287 /*! \brief The tuple Expression */
288 DFPattern tuple;
289 /*! \brief which value to get */
290 int index;
291
292 void VisitAttrs(tvm::AttrVisitor* v) {
293 v->Visit("tuple", &tuple);
294 v->Visit("index", &index);
295 }
296
297 static constexpr const char* _type_key = "relay.dataflow_pattern.TupleGetItemPattern";
298 TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode);
299};
300
301class IfPatternNode : public DFPatternNode {
302 public:
303 DFPattern cond, true_branch, false_branch;
304
305 void VisitAttrs(tvm::AttrVisitor* v) {
306 v->Visit("cond", &cond);
307 v->Visit("true_branch", &true_branch);
308 v->Visit("false_branch", &false_branch);
309 }
310
311 static constexpr const char* _type_key = "relay.dataflow_pattern.IfPattern";
312 TVM_DECLARE_FINAL_OBJECT_INFO(IfPatternNode, DFPatternNode);
313};
314
315class IfPattern : public DFPattern {
316 public:
317 TVM_DLL IfPattern(DFPattern cond, DFPattern then_clause, DFPattern else_clause);
318 TVM_DEFINE_OBJECT_REF_METHODS(IfPattern, DFPattern, IfPatternNode);
319};
320
321class TupleGetItemPattern : public DFPattern {
322 public:
323 TVM_DLL TupleGetItemPattern(DFPattern tuple, int index);
324 TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode);
325};
326
327class AltPattern;
328/*!
329 * \brief Pattern for Alternate Expressions.
330 */
331class AltPatternNode : public DFPatternNode {
332 public:
333 /*! \brief The left optional pattern. */
334 DFPattern left;
335 /*! \brief The right optional pattern. */
336 DFPattern right;
337
338 void VisitAttrs(tvm::AttrVisitor* v) {
339 v->Visit("left", &left);
340 v->Visit("right", &right);
341 }
342
343 static constexpr const char* _type_key = "relay.dataflow_pattern.AltPattern";
344 TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode);
345};
346
347/*!
348 * \brief A pattern which matches either of two patterns
349 */
350class AltPattern : public DFPattern {
351 public:
352 TVM_DLL AltPattern(DFPattern left, DFPattern right);
353 TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode);
354};
355
356/*!
357 * \brief Wildcard Pattern.
358 */
359class WildcardPatternNode : public DFPatternNode {
360 public:
361 void VisitAttrs(tvm::AttrVisitor* v) {}
362
363 static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern";
364 TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
365};
366
367/*!
368 * \brief A pattern which matches anything.
369 */
370class WildcardPattern : public DFPattern {
371 public:
372 TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
373};
374
375class TypePattern;
376/*!
377 * \brief Pattern for Types.
378 */
379class TypePatternNode : public DFPatternNode {
380 public:
381 /*! \brief The pattern. */
382 DFPattern pattern;
383 /*! \brief The type to match */
384 Type type;
385
386 void VisitAttrs(tvm::AttrVisitor* v) {
387 v->Visit("pattern", &pattern);
388 v->Visit("type", &type);
389 }
390
391 static constexpr const char* _type_key = "relay.dataflow_pattern.TypePattern";
392 TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode);
393};
394
395/*!
396 * \brief A pattern which matches a type in another pattern
397 */
398class TypePattern : public DFPattern {
399 public:
400 TVM_DLL TypePattern(DFPattern pattern, Type type);
401 TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode);
402};
403
404class ShapePattern;
405/*!
406 * \brief Pattern for Shapes.
407 */
408class ShapePatternNode : public DFPatternNode {
409 public:
410 /*! \brief The pattern. */
411 DFPattern pattern;
412 /*! \brief The type to match */
413 Array<PrimExpr> shape;
414
415 void VisitAttrs(tvm::AttrVisitor* v) {
416 v->Visit("pattern", &pattern);
417 v->Visit("shape", &shape);
418 }
419
420 static constexpr const char* _type_key = "relay.dataflow_pattern.ShapePattern";
421 TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode);
422};
423
424/*!
425 * \brief A pattern which matches a type in another pattern
426 */
427class ShapePattern : public DFPattern {
428 public:
429 TVM_DLL ShapePattern(DFPattern pattern, Array<PrimExpr> type);
430 TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode);
431};
432
433class DataTypePattern;
434/*!
435 * \brief Pattern for Types.
436 */
437class DataTypePatternNode : public DFPatternNode {
438 public:
439 /*! \brief The pattern. */
440 DFPattern pattern;
441 /*! \brief The type to match */
442 DataType dtype;
443
444 void VisitAttrs(tvm::AttrVisitor* v) {
445 v->Visit("pattern", &pattern);
446 v->Visit("dtype", &dtype);
447 }
448
449 static constexpr const char* _type_key = "relay.dataflow_pattern.DataTypePattern";
450 TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode);
451};
452
453/*!
454 * \brief A pattern which matches a type in another pattern
455 */
456class DataTypePattern : public DFPattern {
457 public:
458 TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype);
459 TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode);
460};
461
462class AttrPattern;
463/*!
464 * \brief Pattern for Attributes.
465 */
466class AttrPatternNode : public DFPatternNode {
467 public:
468 /*! \brief The pattern. */
469 DFPattern pattern;
470 /*! \brief The attribute to match */
471 DictAttrs attrs;
472
473 void VisitAttrs(tvm::AttrVisitor* v) {
474 v->Visit("pattern", &pattern);
475 v->Visit("attrs", &attrs);
476 }
477
478 static constexpr const char* _type_key = "relay.dataflow_pattern.AttrPattern";
479 TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode);
480};
481
482/*!
483 * \brief A pattern which matches attributes in another pattern
484 */
485class AttrPattern : public DFPattern {
486 public:
487 TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs);
488 TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode);
489};
490
491class DominatorPattern;
492/*!
493 * \brief Dominated Graph Pattern
494 * Pattern for fuzzy subgraphs where all outputs of the parent are used finally by the child, and
495 * every operation between the parent and the child matches the path.
496 */
497class DominatorPatternNode : public DFPatternNode {
498 public:
499 /*! \brief The parent. */
500 DFPattern parent;
501 /*! \brief The path. */
502 DFPattern path;
503 /*! \brief The child. */
504 DFPattern child;
505
506 void VisitAttrs(tvm::AttrVisitor* v) {
507 v->Visit("parent", &parent);
508 v->Visit("path", &path);
509 v->Visit("child", &child);
510 }
511
512 static constexpr const char* _type_key = "relay.dataflow_pattern.DominatorPattern";
513 TVM_DECLARE_FINAL_OBJECT_INFO(DominatorPatternNode, DFPatternNode);
514};
515
516/*!
517 * \brief A pattern which matches a variable length dominator path
518 */
519class DominatorPattern : public DFPattern {
520 public:
521 TVM_DLL DominatorPattern(DFPattern parent, DFPattern path, DFPattern child);
522 TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode);
523};
524
525/*! \brief Syntatic Sugar for creating a VarPattern with a name */
526DFPattern IsVar(const String& name);
527/*! \brief Syntatic Sugar for creating a ConstantPattern */
528DFPattern IsConstant();
529/*! \brief Syntatic Sugar for creating a WildcardPattern */
530DFPattern IsWildcard();
531/*! \brief Syntatic Sugar for creating a ExprPattern */
532DFPattern IsExpr(const Expr& expr);
533/*! \brief Syntatic Sugar for creating a ExprPattern base on an Op*/
534DFPattern IsOp(const String& op_name);
535/*! \brief Syntatic Sugar for creating a TuplePattern*/
536DFPattern IsTuple(const Array<DFPattern>& fields);
537/*! \brief Syntatic Sugar for creating a TupleGetItemPattern*/
538DFPattern IsTupleGetItem(const DFPattern tuple, int index = -1);
539
540} // namespace relay
541} // namespace tvm
542#endif // TVM_RELAY_DATAFLOW_PATTERN_H_
543