1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/dataflow_matcher.h |
22 | * \brief A pattern matcher for matching dataflow properties. |
23 | */ |
24 | #ifndef TVM_RELAY_DATAFLOW_MATCHER_H_ |
25 | #define TVM_RELAY_DATAFLOW_MATCHER_H_ |
26 | |
27 | #include <tvm/relay/dataflow_pattern.h> |
28 | #include <tvm/relay/dataflow_pattern_functor.h> |
29 | |
30 | #include <string> |
31 | #include <unordered_map> |
32 | #include <utility> |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | |
37 | class DFPatternCallback; |
38 | /*! |
39 | * \brief Base type of all dataflow pattern callbacks. |
40 | * \sa DFPatternCallback |
41 | */ |
42 | class DFPatternCallbackNode : public Object { |
43 | public: |
44 | /*! \brief Pattern this callback matches */ |
45 | DFPattern pattern; |
46 | /*! \brief Function to call when finding a matched expression */ |
47 | PackedFunc function; |
48 | /*! \brief Require InferType to be run before the callback */ |
49 | bool require_type; |
50 | /*! \brief Run the callback only once */ |
51 | bool rewrite_once; |
52 | |
53 | void VisitAttrs(tvm::AttrVisitor* v) { |
54 | v->Visit("pattern" , &pattern); |
55 | v->Visit("require_type" , &require_type); |
56 | v->Visit("rewrite_once" , &rewrite_once); |
57 | } |
58 | |
59 | static constexpr const char* _type_key = "DFPatternCallbackNode" ; |
60 | TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object); |
61 | }; |
62 | |
63 | /*! |
64 | * \brief Managed reference to dataflow pattern callbacks. |
65 | * \sa DFPatternCallbackNode |
66 | */ |
67 | class DFPatternCallback : public ObjectRef { |
68 | public: |
69 | TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type, |
70 | bool rewrite_once = false); |
71 | TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); |
72 | }; |
73 | |
74 | /*! |
75 | * \brief Determine if a pattern matches an expression |
76 | * |
77 | * \param pattern The pattern to match |
78 | * \param expr The expression to match |
79 | * |
80 | * \return Return true if the pattern and the expression match, return false otherwise. |
81 | */ |
82 | bool MatchPattern(DFPattern pattern, Expr expr); |
83 | |
84 | /*! |
85 | * \brief Rewrite an expression based on some number of DFPatternCallbacks |
86 | * |
87 | * \param callbacks An array of DFPatternCallback Nodes |
88 | * \param expr The expression to rewrite |
89 | * \param mod The module that associates with the expr |
90 | * |
91 | * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the |
92 | * functions inside the callbacks |
93 | */ |
94 | Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod = IRModule()); |
95 | |
96 | /*! |
97 | * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls |
98 | * |
99 | * \param pattern The pattern to match |
100 | * \param expr The expression to patition |
101 | * \param attrs A set of parameter names and values to apply to the partitioned function |
102 | * \param check A callback function for checking more complicated properties of the matched |
103 | * expressions, returns true if the match is accepted and false otherwise |
104 | * |
105 | * \return Return the paritioned Expr. |
106 | */ |
107 | Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs, PackedFunc check); |
108 | |
109 | /*! |
110 | * \brief Infer the type of an expression. |
111 | * |
112 | * \param expr The expression to rewrite |
113 | * |
114 | * \return Return An Expr with unambiguous type information filled in, as well as it's |
115 | * checked type field populated with the result type. |
116 | * |
117 | */ |
118 | Expr InferType(const Expr& expr); |
119 | |
120 | } // namespace relay |
121 | } // namespace tvm |
122 | |
123 | #endif // TVM_RELAY_DATAFLOW_MATCHER_H_ |
124 | |