1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file src/relay/ir/dataflow_matcher_impl.h
 * \brief Auxiliary data structures for the dataflow pattern matcher.
 */
24 | #ifndef TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_ |
25 | #define TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_ |
26 | |
27 | #include <tvm/relay/dataflow_matcher.h> |
28 | #include <tvm/relay/dataflow_pattern.h> |
29 | #include <tvm/relay/dataflow_pattern_functor.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | |
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
36 | |
37 | #include "indexed_graph.h" |
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | |
42 | class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> { |
43 | public: |
44 | explicit DFPatternMatcher(const IndexedGraph<Expr>* expr_graph) : expr_graph_(expr_graph) {} |
45 | bool Match(const DFPattern& pattern, const Expr& expr); |
46 | Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); } |
47 | |
48 | const IndexedGraph<Expr>::Node* expr_to_node(const Expr& expr) const { |
49 | return expr_graph_->item_to_node(expr); |
50 | } |
51 | const IndexedGraph<Expr>::Node* index_to_node(size_t index) const { |
52 | return expr_graph_->index_to_node(index); |
53 | } |
54 | size_t size() const { return expr_graph_->size(); } |
55 | const std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>& memo() const { |
56 | return memo_; |
57 | } |
58 | const IndexedGraph<Expr>& expr_graph() const { return *expr_graph_; } |
59 | |
60 | protected: |
61 | bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; |
62 | bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; |
63 | bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; |
64 | bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; |
65 | bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; |
66 | bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; |
67 | bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override; |
68 | bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; |
69 | bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override; |
70 | bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override; |
71 | bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override; |
72 | bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; |
73 | bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; |
74 | bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; |
75 | bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; |
76 | bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; |
77 | bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; |
78 | |
79 | void ClearMap(size_t watermark); |
80 | bool MatchesPath(const DominatorPatternNode* op, const Expr& expr); |
81 | bool DominatesParent(const DominatorPatternNode* op, const Expr& expr); |
82 | |
83 | const IndexedGraph<Expr>* expr_graph_; |
84 | std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> memo_; |
85 | std::vector<DFPattern> matched_nodes_; |
86 | bool memoize_ = true; |
87 | }; |
88 | |
89 | /*! |
90 | * \brief PatternGrouper does pre-rewriting pattern matching and analysis |
91 | * |
92 | * This class creates a number of groups of matched expressions, ensures they don't overlap, and |
93 | * returns them to the caller for post-analysis rewriting. |
94 | * |
95 | * This is primarily needed to support the post-dominator analysis required for dominator pattern |
96 | * matching. |
97 | */ |
98 | class PatternGrouper { |
99 | public: |
100 | /*! \brief Internal Group class for storing analysis */ |
101 | struct Group { |
102 | Expr root_node; |
103 | int gid; |
104 | Map<DFPattern, Array<Expr>> matched_nodes; |
105 | std::string name; |
106 | Function function; |
107 | Array<Expr> args; |
108 | }; |
109 | |
110 | /*! \brief Return the group assignments of expressions */ |
111 | inline const std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>& GetGIDAssignments() { |
112 | return gid_assignments_; |
113 | } |
114 | /*! \brief Group expressions that match the pattern */ |
115 | const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern, const Expr& pre); |
116 | |
117 | protected: |
118 | /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs |
119 | * |
120 | * If we traverse the graph in post-order, we can run into situtations where a small subgraph will |
121 | * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in |
122 | * the graph may also match the pattern. With post-order traversal, we mark the smaller subgraph |
123 | * as matched and fail to catch the larger subgraph. This problem is fixed by using pre-order |
124 | * traversal. |
125 | */ |
126 | void VisitExprs(); |
127 | |
128 | /*! \brief Create a group based on a matched expression */ |
129 | void CreateGroup(const Expr& expr); |
130 | |
131 | /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or |
132 | * lifting them into the function arguments. |
133 | * |
134 | * The rules depend on what pattern the ConstantNode matched. |
135 | * |
136 | * The basic rules are: |
137 | * If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant |
138 | * in the partitioned function. If the constant matched an AltPattern, recursively check the |
139 | * matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc), |
140 | * lift the constant into the arguments of the partitioned function. |
141 | */ |
142 | bool EmbedConst(const Expr& expr, const DFPattern pattern); |
143 | // Internal State |
144 | DFPattern pattern_; |
145 | std::unordered_map<int, Group> groups_; |
146 | std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_; |
147 | DFPatternMatcher* matcher_ = nullptr; |
148 | std::unique_ptr<IndexedGraph<DFPattern>> pattern_graph_; |
149 | int gid_ = 0; |
150 | int graph_number_ = 0; |
151 | }; |
152 | |
153 | /*! |
154 | * \brief PatternRewriter rewrites the expression by finding matches and allowing user callback |
155 | * function to rewrite those matches |
156 | * |
157 | * The class uses PatternGrouper to support the dominator pattern. |
158 | */ |
159 | class PatternRewriter : protected MixedModeMutator { |
160 | public: |
161 | explicit PatternRewriter(IRModule mod) : mod_(mod) {} |
162 | /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the |
163 | * callbacks until it stops changing */ |
164 | virtual Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre); |
165 | |
166 | protected: |
167 | virtual Expr DispatchVisitExpr(const Expr& pre); |
168 | |
169 | IRModule mod_; |
170 | DFPatternCallback callback_; |
171 | std::unordered_map<int, PatternGrouper::Group> groups_; |
172 | std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_; |
173 | }; |
174 | |
175 | } // namespace relay |
176 | } // namespace tvm |
177 | |
178 | #endif // TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_ |
179 | |