1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/dataflow_pattern.h |
22 | * \brief A pattern language for matching dataflow properties. |
23 | */ |
24 | #ifndef TVM_RELAY_DATAFLOW_PATTERN_H_ |
25 | #define TVM_RELAY_DATAFLOW_PATTERN_H_ |
26 | |
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>

#include <functional>
#include <string>
#include <vector>
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | |
36 | /*! |
37 | * \brief Base type of all dataflow patterns. |
38 | * \sa DFPattern |
39 | */ |
40 | class DFPatternNode : public Object { |
41 | public: |
42 | static constexpr const char* _type_key = "DFPatternNode" ; |
43 | TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object); |
44 | }; |
45 | |
46 | /*! |
47 | * \brief Managed reference to dataflow patterns. |
48 | * \sa DFPatternNode |
49 | */ |
50 | class DFPattern : public ObjectRef { |
51 | public: |
52 | /*! \brief Syntatic Sugar for creating a CallPattern */ |
53 | DFPattern operator()(const std::vector<DFPattern>& args) const; |
54 | /*! \brief Syntatic Sugar for creating a CallPattern with an "add" op */ |
55 | DFPattern operator+(const DFPattern& other) const; |
56 | /*! \brief Syntatic Sugar for creating a CallPattern with a "subtract" op */ |
57 | DFPattern operator-(const DFPattern& other) const; |
58 | /*! \brief Syntatic Sugar for creating a CallPattern with a "multiply" op */ |
59 | DFPattern operator*(const DFPattern& other) const; |
60 | /*! \brief Syntatic Sugar for creating a CallPattern with a "divide" op */ |
61 | DFPattern operator/(const DFPattern& other) const; |
62 | /*! \brief Syntatic Sugar for creating an AltPattern */ |
63 | DFPattern operator||(const DFPattern& other) const; |
64 | /*! \brief Syntatic Sugar for creating an Optional Pattern */ |
65 | DFPattern Optional(const std::function<DFPattern(const DFPattern&)>& func) const; |
66 | /*! \brief Syntatic Sugar for creating an AttrPattern */ |
67 | DFPattern HasAttr(const Map<String, ObjectRef>& attrs) const; |
68 | /*! \brief Syntatic Sugar for creating a TypePattern */ |
69 | DFPattern HasType(const Type& type) const; |
70 | /*! \brief Syntatic Sugar for creating a DataTypePattern with a DataType */ |
71 | DFPattern HasDtype(const DataType& dtype) const; |
72 | /*! \brief Syntatic Sugar for creating a DataTypePattern with a data type's name */ |
73 | DFPattern HasDtype(const std::string& dtype) const; |
74 | /*! \brief Syntatic Sugar for creating a ShapePattern */ |
75 | DFPattern HasShape(const Array<PrimExpr> shape) const; |
76 | |
77 | TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode); |
78 | }; |
79 | |
80 | /*! |
81 | * \brief Pattern for Relay Expression. |
82 | */ |
83 | class ExprPatternNode : public DFPatternNode { |
84 | public: |
85 | /*! \brief The expression to match. */ |
86 | Expr expr; |
87 | |
88 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr" , &expr); } |
89 | |
90 | static constexpr const char* _type_key = "relay.dataflow_pattern.ExprPattern" ; |
91 | TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); |
92 | }; |
93 | |
94 | /*! |
95 | * \brief A pattern which matches a literal expression. |
96 | * |
97 | * \note Uses structural equality on expressions to check equality. |
98 | * |
99 | */ |
100 | class ExprPattern : public DFPattern { |
101 | public: |
102 | TVM_DLL explicit ExprPattern(Expr expr); |
103 | TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); |
104 | }; |
105 | |
106 | /*! |
107 | * \brief A Pattern to Match a Relay Variable |
108 | */ |
109 | class VarPattern; |
110 | /*! \brief Container for Var */ |
111 | class VarPatternNode : public DFPatternNode { |
112 | public: |
113 | /*! |
114 | * \brief The name of the Var (optional). |
115 | */ |
116 | String name; |
117 | |
118 | /*! \return The name hint of the variable */ |
119 | const String& name_hint() const { return name; } |
120 | |
121 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name" , &name); } |
122 | |
123 | static constexpr const char* _type_key = "relay.dataflow_pattern.VarPattern" ; |
124 | TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode); |
125 | }; |
126 | |
127 | class VarPattern : public DFPattern { |
128 | public: |
129 | TVM_DLL VarPattern(String name_hint); |
130 | TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); |
131 | }; |
132 | |
133 | /*! |
134 | * \brief A Pattern to Match a Relay Constant |
135 | */ |
136 | class ConstantPattern; |
137 | /*! \brief Container for Constant */ |
138 | class ConstantPatternNode : public DFPatternNode { |
139 | public: |
140 | void VisitAttrs(tvm::AttrVisitor* v) {} |
141 | |
142 | static constexpr const char* _type_key = "relay.dataflow_pattern.ConstantPattern" ; |
143 | TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode); |
144 | }; |
145 | |
146 | class ConstantPattern : public DFPattern { |
147 | public: |
148 | TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode); |
149 | }; |
150 | |
151 | /*! |
152 | * \brief Call corresponds to operator invocation. |
153 | * Corresponds to the operator in computational graph terminology. |
154 | */ |
155 | class CallPattern; |
156 | /*! \brief CallPattern container. */ |
157 | class CallPatternNode : public DFPatternNode { |
158 | public: |
159 | /*! |
160 | * \brief The operator(function) being invoked |
161 | * |
162 | * - It can be relay::Op which corresponds to the primitive operators. |
163 | * - It can also be user defined functions (Function, GlobalVar, Var). |
164 | */ |
165 | DFPattern op; |
166 | |
167 | /*! \brief The arguments(inputs) of the call */ |
168 | tvm::Array<relay::DFPattern> args; |
169 | |
170 | void VisitAttrs(tvm::AttrVisitor* v) { |
171 | v->Visit("op" , &op); |
172 | v->Visit("args" , &args); |
173 | } |
174 | |
175 | static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern" ; |
176 | TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); |
177 | }; |
178 | |
179 | class CallPattern : public DFPattern { |
180 | public: |
181 | TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args); |
182 | TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); |
183 | }; |
184 | |
185 | /*! |
186 | * \brief Relay Function container |
187 | * \sa Function |
188 | */ |
189 | class FunctionPatternNode : public DFPatternNode { |
190 | public: |
191 | /*! \brief Function parameters */ |
192 | tvm::Array<DFPattern> params; |
193 | /*! |
194 | * \brief |
195 | * The expression which represents the computation of the function, |
196 | * the expression may reference the parameters, and the type of it |
197 | * or sub-expressions may reference the type variables. |
198 | */ |
199 | DFPattern body; |
200 | |
201 | void VisitAttrs(tvm::AttrVisitor* v) { |
202 | v->Visit("params" , ¶ms); |
203 | v->Visit("body" , &body); |
204 | } |
205 | |
206 | static constexpr const char* _type_key = "relay.dataflow_pattern.FunctionPattern" ; |
207 | TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode); |
208 | }; |
209 | |
210 | /*! |
211 | * \brief Managed reference to FunctionNode. |
212 | * \sa FunctionNode |
213 | */ |
214 | class FunctionPattern : public DFPattern { |
215 | public: |
216 | /*! |
217 | * \brief Constructor |
218 | * \param params The parameters of the function. |
219 | * \param body The body of the function. |
220 | */ |
221 | TVM_DLL FunctionPattern(tvm::Array<DFPattern> params, DFPattern body); |
222 | |
223 | TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, FunctionPatternNode); |
224 | TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode); |
225 | }; |
226 | |
227 | /*! \brief A binding of a sub-network. */ |
228 | class LetPatternNode : public DFPatternNode { |
229 | public: |
230 | /*! \brief The variable we bind to */ |
231 | DFPattern var; |
232 | /*! \brief The value we bind var to */ |
233 | DFPattern value; |
234 | /*! \brief The body of the let binding */ |
235 | DFPattern body; |
236 | |
237 | void VisitAttrs(tvm::AttrVisitor* v) { |
238 | v->Visit("var" , &var); |
239 | v->Visit("value" , &value); |
240 | v->Visit("body" , &body); |
241 | } |
242 | |
243 | static constexpr const char* _type_key = "relay.dataflow_pattern.LetPattern" ; |
244 | TVM_DECLARE_FINAL_OBJECT_INFO(LetPatternNode, DFPatternNode); |
245 | }; |
246 | |
247 | /*! |
248 | * \brief Let binding that binds a local var |
249 | */ |
250 | class LetPattern : public DFPattern { |
251 | public: |
252 | /*! |
253 | * \brief The constructor |
254 | * \param var The variable that is bound to. |
255 | * \param value The value used to bind to the variable. |
256 | * \param body The body of the let binding. |
257 | */ |
258 | TVM_DLL LetPattern(DFPattern var, DFPattern value, DFPattern body); |
259 | |
260 | TVM_DEFINE_OBJECT_REF_METHODS(LetPattern, DFPattern, LetPatternNode); |
261 | }; |
262 | |
263 | /*! \brief Tuple of multiple Exprs */ |
264 | class TuplePattern; |
265 | /*! \brief Tuple container */ |
266 | class TuplePatternNode : public DFPatternNode { |
267 | public: |
268 | /*! \brief the fields of the tuple */ |
269 | tvm::Array<DFPattern> fields; |
270 | |
271 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields" , &fields); } |
272 | |
273 | static constexpr const char* _type_key = "relay.dataflow_pattern.TuplePattern" ; |
274 | TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); |
275 | }; |
276 | |
277 | class TuplePattern : public DFPattern { |
278 | public: |
279 | TVM_DLL explicit TuplePattern(tvm::Array<DFPattern> fields); |
280 | TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); |
281 | }; |
282 | |
283 | /*! \brief Get index-th field out of a tuple. */ |
284 | class TupleGetItemPattern; |
285 | class TupleGetItemPatternNode : public DFPatternNode { |
286 | public: |
287 | /*! \brief The tuple Expression */ |
288 | DFPattern tuple; |
289 | /*! \brief which value to get */ |
290 | int index; |
291 | |
292 | void VisitAttrs(tvm::AttrVisitor* v) { |
293 | v->Visit("tuple" , &tuple); |
294 | v->Visit("index" , &index); |
295 | } |
296 | |
297 | static constexpr const char* _type_key = "relay.dataflow_pattern.TupleGetItemPattern" ; |
298 | TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); |
299 | }; |
300 | |
301 | class IfPatternNode : public DFPatternNode { |
302 | public: |
303 | DFPattern cond, true_branch, false_branch; |
304 | |
305 | void VisitAttrs(tvm::AttrVisitor* v) { |
306 | v->Visit("cond" , &cond); |
307 | v->Visit("true_branch" , &true_branch); |
308 | v->Visit("false_branch" , &false_branch); |
309 | } |
310 | |
311 | static constexpr const char* _type_key = "relay.dataflow_pattern.IfPattern" ; |
312 | TVM_DECLARE_FINAL_OBJECT_INFO(IfPatternNode, DFPatternNode); |
313 | }; |
314 | |
315 | class IfPattern : public DFPattern { |
316 | public: |
317 | TVM_DLL IfPattern(DFPattern cond, DFPattern then_clause, DFPattern else_clause); |
318 | TVM_DEFINE_OBJECT_REF_METHODS(IfPattern, DFPattern, IfPatternNode); |
319 | }; |
320 | |
321 | class TupleGetItemPattern : public DFPattern { |
322 | public: |
323 | TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); |
324 | TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); |
325 | }; |
326 | |
327 | class AltPattern; |
328 | /*! |
329 | * \brief Pattern for Alternate Expressions. |
330 | */ |
331 | class AltPatternNode : public DFPatternNode { |
332 | public: |
333 | /*! \brief The left optional pattern. */ |
334 | DFPattern left; |
335 | /*! \brief The right optional pattern. */ |
336 | DFPattern right; |
337 | |
338 | void VisitAttrs(tvm::AttrVisitor* v) { |
339 | v->Visit("left" , &left); |
340 | v->Visit("right" , &right); |
341 | } |
342 | |
343 | static constexpr const char* _type_key = "relay.dataflow_pattern.AltPattern" ; |
344 | TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode); |
345 | }; |
346 | |
347 | /*! |
348 | * \brief A pattern which matches either of two patterns |
349 | */ |
350 | class AltPattern : public DFPattern { |
351 | public: |
352 | TVM_DLL AltPattern(DFPattern left, DFPattern right); |
353 | TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode); |
354 | }; |
355 | |
356 | /*! |
357 | * \brief Wildcard Pattern. |
358 | */ |
359 | class WildcardPatternNode : public DFPatternNode { |
360 | public: |
361 | void VisitAttrs(tvm::AttrVisitor* v) {} |
362 | |
363 | static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern" ; |
364 | TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); |
365 | }; |
366 | |
367 | /*! |
368 | * \brief A pattern which matches anything. |
369 | */ |
370 | class WildcardPattern : public DFPattern { |
371 | public: |
372 | TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); |
373 | }; |
374 | |
375 | class TypePattern; |
376 | /*! |
377 | * \brief Pattern for Types. |
378 | */ |
379 | class TypePatternNode : public DFPatternNode { |
380 | public: |
381 | /*! \brief The pattern. */ |
382 | DFPattern pattern; |
383 | /*! \brief The type to match */ |
384 | Type type; |
385 | |
386 | void VisitAttrs(tvm::AttrVisitor* v) { |
387 | v->Visit("pattern" , &pattern); |
388 | v->Visit("type" , &type); |
389 | } |
390 | |
391 | static constexpr const char* _type_key = "relay.dataflow_pattern.TypePattern" ; |
392 | TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); |
393 | }; |
394 | |
395 | /*! |
396 | * \brief A pattern which matches a type in another pattern |
397 | */ |
398 | class TypePattern : public DFPattern { |
399 | public: |
400 | TVM_DLL TypePattern(DFPattern pattern, Type type); |
401 | TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); |
402 | }; |
403 | |
404 | class ShapePattern; |
405 | /*! |
406 | * \brief Pattern for Shapes. |
407 | */ |
408 | class ShapePatternNode : public DFPatternNode { |
409 | public: |
410 | /*! \brief The pattern. */ |
411 | DFPattern pattern; |
412 | /*! \brief The type to match */ |
413 | Array<PrimExpr> shape; |
414 | |
415 | void VisitAttrs(tvm::AttrVisitor* v) { |
416 | v->Visit("pattern" , &pattern); |
417 | v->Visit("shape" , &shape); |
418 | } |
419 | |
420 | static constexpr const char* _type_key = "relay.dataflow_pattern.ShapePattern" ; |
421 | TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode); |
422 | }; |
423 | |
424 | /*! |
425 | * \brief A pattern which matches a type in another pattern |
426 | */ |
427 | class ShapePattern : public DFPattern { |
428 | public: |
429 | TVM_DLL ShapePattern(DFPattern pattern, Array<PrimExpr> type); |
430 | TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode); |
431 | }; |
432 | |
433 | class DataTypePattern; |
434 | /*! |
435 | * \brief Pattern for Types. |
436 | */ |
437 | class DataTypePatternNode : public DFPatternNode { |
438 | public: |
439 | /*! \brief The pattern. */ |
440 | DFPattern pattern; |
441 | /*! \brief The type to match */ |
442 | DataType dtype; |
443 | |
444 | void VisitAttrs(tvm::AttrVisitor* v) { |
445 | v->Visit("pattern" , &pattern); |
446 | v->Visit("dtype" , &dtype); |
447 | } |
448 | |
449 | static constexpr const char* _type_key = "relay.dataflow_pattern.DataTypePattern" ; |
450 | TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode); |
451 | }; |
452 | |
453 | /*! |
454 | * \brief A pattern which matches a type in another pattern |
455 | */ |
456 | class DataTypePattern : public DFPattern { |
457 | public: |
458 | TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); |
459 | TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode); |
460 | }; |
461 | |
462 | class AttrPattern; |
463 | /*! |
464 | * \brief Pattern for Attributes. |
465 | */ |
466 | class AttrPatternNode : public DFPatternNode { |
467 | public: |
468 | /*! \brief The pattern. */ |
469 | DFPattern pattern; |
470 | /*! \brief The attribute to match */ |
471 | DictAttrs attrs; |
472 | |
473 | void VisitAttrs(tvm::AttrVisitor* v) { |
474 | v->Visit("pattern" , &pattern); |
475 | v->Visit("attrs" , &attrs); |
476 | } |
477 | |
478 | static constexpr const char* _type_key = "relay.dataflow_pattern.AttrPattern" ; |
479 | TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); |
480 | }; |
481 | |
482 | /*! |
483 | * \brief A pattern which matches attributes in another pattern |
484 | */ |
485 | class AttrPattern : public DFPattern { |
486 | public: |
487 | TVM_DLL AttrPattern(DFPattern pattern, DictAttrs attrs); |
488 | TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); |
489 | }; |
490 | |
491 | class DominatorPattern; |
492 | /*! |
493 | * \brief Dominated Graph Pattern |
494 | * Pattern for fuzzy subgraphs where all outputs of the parent are used finally by the child, and |
495 | * every operation between the parent and the child matches the path. |
496 | */ |
497 | class DominatorPatternNode : public DFPatternNode { |
498 | public: |
499 | /*! \brief The parent. */ |
500 | DFPattern parent; |
501 | /*! \brief The path. */ |
502 | DFPattern path; |
503 | /*! \brief The child. */ |
504 | DFPattern child; |
505 | |
506 | void VisitAttrs(tvm::AttrVisitor* v) { |
507 | v->Visit("parent" , &parent); |
508 | v->Visit("path" , &path); |
509 | v->Visit("child" , &child); |
510 | } |
511 | |
512 | static constexpr const char* _type_key = "relay.dataflow_pattern.DominatorPattern" ; |
513 | TVM_DECLARE_FINAL_OBJECT_INFO(DominatorPatternNode, DFPatternNode); |
514 | }; |
515 | |
516 | /*! |
517 | * \brief A pattern which matches a variable length dominator path |
518 | */ |
519 | class DominatorPattern : public DFPattern { |
520 | public: |
521 | TVM_DLL DominatorPattern(DFPattern parent, DFPattern path, DFPattern child); |
522 | TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode); |
523 | }; |
524 | |
525 | /*! \brief Syntatic Sugar for creating a VarPattern with a name */ |
526 | DFPattern IsVar(const String& name); |
527 | /*! \brief Syntatic Sugar for creating a ConstantPattern */ |
528 | DFPattern IsConstant(); |
529 | /*! \brief Syntatic Sugar for creating a WildcardPattern */ |
530 | DFPattern IsWildcard(); |
531 | /*! \brief Syntatic Sugar for creating a ExprPattern */ |
532 | DFPattern IsExpr(const Expr& expr); |
533 | /*! \brief Syntatic Sugar for creating a ExprPattern base on an Op*/ |
534 | DFPattern IsOp(const String& op_name); |
535 | /*! \brief Syntatic Sugar for creating a TuplePattern*/ |
536 | DFPattern IsTuple(const Array<DFPattern>& fields); |
537 | /*! \brief Syntatic Sugar for creating a TupleGetItemPattern*/ |
538 | DFPattern IsTupleGetItem(const DFPattern tuple, int index = -1); |
539 | |
540 | } // namespace relay |
541 | } // namespace tvm |
542 | #endif // TVM_RELAY_DATAFLOW_PATTERN_H_ |
543 | |