1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/tvm/relay/dataflow_matcher_impl.h
22 * \brief The auxiliary data structure for dataflow matcher.
23 */
24#ifndef TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_
25#define TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_
26
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/dataflow_pattern.h>
#include <tvm/relay/dataflow_pattern_functor.h>
#include <tvm/relay/expr_functor.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "indexed_graph.h"
38
39namespace tvm {
40namespace relay {
41
42class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
43 public:
44 explicit DFPatternMatcher(const IndexedGraph<Expr>* expr_graph) : expr_graph_(expr_graph) {}
45 bool Match(const DFPattern& pattern, const Expr& expr);
46 Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
47
48 const IndexedGraph<Expr>::Node* expr_to_node(const Expr& expr) const {
49 return expr_graph_->item_to_node(expr);
50 }
51 const IndexedGraph<Expr>::Node* index_to_node(size_t index) const {
52 return expr_graph_->index_to_node(index);
53 }
54 size_t size() const { return expr_graph_->size(); }
55 const std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>& memo() const {
56 return memo_;
57 }
58 const IndexedGraph<Expr>& expr_graph() const { return *expr_graph_; }
59
60 protected:
61 bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
62 bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
63 bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
64 bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
65 bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
66 bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
67 bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
68 bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
69 bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
70 bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
71 bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
72 bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
73 bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
74 bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
75 bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
76 bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
77 bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
78
79 void ClearMap(size_t watermark);
80 bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
81 bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
82
83 const IndexedGraph<Expr>* expr_graph_;
84 std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> memo_;
85 std::vector<DFPattern> matched_nodes_;
86 bool memoize_ = true;
87};
88
89/*!
90 * \brief PatternGrouper does pre-rewriting pattern matching and analysis
91 *
92 * This class creates a number of groups of matched expressions, ensures they don't overlap, and
93 * returns them to the caller for post-analysis rewriting.
94 *
95 * This is primarily needed to support the post-dominator analysis required for dominator pattern
96 * matching.
97 */
98class PatternGrouper {
99 public:
100 /*! \brief Internal Group class for storing analysis */
101 struct Group {
102 Expr root_node;
103 int gid;
104 Map<DFPattern, Array<Expr>> matched_nodes;
105 std::string name;
106 Function function;
107 Array<Expr> args;
108 };
109
110 /*! \brief Return the group assignments of expressions */
111 inline const std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>& GetGIDAssignments() {
112 return gid_assignments_;
113 }
114 /*! \brief Group expressions that match the pattern */
115 const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern, const Expr& pre);
116
117 protected:
118 /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs
119 *
120 * If we traverse the graph in post-order, we can run into situtations where a small subgraph will
121 * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in
122 * the graph may also match the pattern. With post-order traversal, we mark the smaller subgraph
123 * as matched and fail to catch the larger subgraph. This problem is fixed by using pre-order
124 * traversal.
125 */
126 void VisitExprs();
127
128 /*! \brief Create a group based on a matched expression */
129 void CreateGroup(const Expr& expr);
130
131 /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or
132 * lifting them into the function arguments.
133 *
134 * The rules depend on what pattern the ConstantNode matched.
135 *
136 * The basic rules are:
137 * If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant
138 * in the partitioned function. If the constant matched an AltPattern, recursively check the
139 * matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc),
140 * lift the constant into the arguments of the partitioned function.
141 */
142 bool EmbedConst(const Expr& expr, const DFPattern pattern);
143 // Internal State
144 DFPattern pattern_;
145 std::unordered_map<int, Group> groups_;
146 std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
147 DFPatternMatcher* matcher_ = nullptr;
148 std::unique_ptr<IndexedGraph<DFPattern>> pattern_graph_;
149 int gid_ = 0;
150 int graph_number_ = 0;
151};
152
153/*!
154 * \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
155 * function to rewrite those matches
156 *
157 * The class uses PatternGrouper to support the dominator pattern.
158 */
159class PatternRewriter : protected MixedModeMutator {
160 public:
161 explicit PatternRewriter(IRModule mod) : mod_(mod) {}
162 /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the
163 * callbacks until it stops changing */
164 virtual Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre);
165
166 protected:
167 virtual Expr DispatchVisitExpr(const Expr& pre);
168
169 IRModule mod_;
170 DFPatternCallback callback_;
171 std::unordered_map<int, PatternGrouper::Group> groups_;
172 std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
173};
174
175} // namespace relay
176} // namespace tvm
177
178#endif // TVM_RELAY_IR_DATAFLOW_MATCHER_IMPL_H_
179