1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/dataflow_pattern_functor.h |
22 | * \brief A set of passes for operating on pattern graphs. |
23 | */ |
24 | #ifndef TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ |
25 | #define TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ |
26 | |
27 | #include <tvm/relay/dataflow_pattern.h> |
28 | |
29 | #include <unordered_set> |
30 | #include <utility> |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
/*!
 * \brief A dynamic functor that dispatches on the runtime node type of its
 *  first (DFPattern) argument.
 *
 * \tparam FType The function signature. Only signatures of the form
 *  R(const DFPattern&, Args...) are supported; that form is provided by the
 *  partial specialization that follows.
 */
template <class FType>
class DFPatternFunctor;
44 | |
// Default body for the overridable VisitDFPattern_ overloads: forward the call
// to VisitDFPatternDefault_. The expansion relies on the enclosing function
// naming its node parameter `op`, its extra-argument pack `args`, and on `Args`
// being the extra-argument type pack.
#define DFPATTERN_FUNCTOR_DEFAULT                                     \
  {                                                                   \
    return VisitDFPatternDefault_(op, std::forward<Args>(args)...);   \
  }
48 | |
// Register a dispatch entry for node type OP in the local `vtable`. The entry
// downcasts the incoming ObjectRef to `const OP*` and invokes the matching
// VisitDFPattern_ overload on the functor, forwarding any extra arguments.
// Expects `vtable`, `TSelf` and `Args` to be in scope, as they are inside
// DFPatternFunctor::InitVTable.
#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                  \
  vtable.template set_dispatch<OP>([](const ObjectRef& node, TSelf* functor, Args... rest) { \
    const OP* typed = static_cast<const OP*>(node.get());                                    \
    return functor->VisitDFPattern_(typed, std::forward<Args>(rest)...);                     \
  });
53 | |
54 | template <typename R, typename... Args> |
55 | class DFPatternFunctor<R(const DFPattern& n, Args...)> { |
56 | private: |
57 | using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>; |
58 | using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; |
59 | |
60 | public: |
61 | /*! \brief virtual destructor */ |
62 | virtual ~DFPatternFunctor() {} |
63 | /*! |
64 | * \brief Same as call. |
65 | * \param n The expression node. |
66 | * \param args Additional arguments. |
67 | * \return The result of the call |
68 | */ |
69 | R operator()(const DFPattern& n, Args... args) { |
70 | return VisitDFPattern(n, std::forward<Args>(args)...); |
71 | } |
72 | /*! |
73 | * \brief The functor call. |
74 | * \param n The expression node. |
75 | * \param args Additional arguments. |
76 | * \return The result of the call |
77 | */ |
78 | virtual R VisitDFPattern(const DFPattern& n, Args... args) { |
79 | ICHECK(n.defined()); |
80 | static FType vtable = InitVTable(); |
81 | return vtable(n, this, std::forward<Args>(args)...); |
82 | } |
83 | // Functions that can be overriden by subclass |
84 | virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
85 | virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
86 | virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
87 | virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
88 | virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
89 | virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
90 | virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
91 | virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
92 | virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
93 | virtual R VisitDFPattern_(const LetPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
94 | virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
95 | virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, |
96 | Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
97 | virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
98 | virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
99 | virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
100 | virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; |
101 | virtual R VisitDFPatternDefault_(const Object* op, Args...) { |
102 | LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); |
103 | throw; |
104 | } |
105 | |
106 | private: |
107 | // initialize the vtable. |
108 | static FType InitVTable() { |
109 | FType vtable; |
110 | // Set dispatch |
111 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); |
112 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); |
113 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); |
114 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode); |
115 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode); |
116 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); |
117 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); |
118 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode); |
119 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode); |
120 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(LetPatternNode); |
121 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); |
122 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); |
123 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); |
124 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); |
125 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); |
126 | RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); |
127 | return vtable; |
128 | } |
129 | }; |
130 | |
131 | /*! |
132 | * \brief A simple visitor wrapper around DFPatternFunctor. |
133 | * Recursively visit the content. |
134 | * |
135 | * DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once. |
136 | */ |
137 | class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> { |
138 | public: |
139 | void VisitDFPattern(const DFPattern& pattern) override; |
140 | void VisitDFPattern_(const AltPatternNode* op) override; |
141 | void VisitDFPattern_(const AttrPatternNode* op) override; |
142 | void VisitDFPattern_(const CallPatternNode* op) override; |
143 | void VisitDFPattern_(const ConstantPatternNode* op) override; |
144 | void VisitDFPattern_(const DataTypePatternNode* op) override; |
145 | void VisitDFPattern_(const DominatorPatternNode* op) override; |
146 | void VisitDFPattern_(const ExprPatternNode* op) override; |
147 | void VisitDFPattern_(const FunctionPatternNode* op) override; |
148 | void VisitDFPattern_(const IfPatternNode* op) override; |
149 | void VisitDFPattern_(const LetPatternNode* op) override; |
150 | void VisitDFPattern_(const ShapePatternNode* op) override; |
151 | void VisitDFPattern_(const TupleGetItemPatternNode* op) override; |
152 | void VisitDFPattern_(const TuplePatternNode* op) override; |
153 | void VisitDFPattern_(const TypePatternNode* op) override; |
154 | void VisitDFPattern_(const VarPatternNode* op) override; |
155 | void VisitDFPattern_(const WildcardPatternNode* op) override; |
156 | |
157 | protected: |
158 | // set of already-visited nodes |
159 | std::unordered_set<const Object*> visited_; |
160 | }; |
161 | |
162 | } // namespace relay |
163 | } // namespace tvm |
164 | #endif // TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ |
165 | |