1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/pattern_functor.h |
22 | * \brief A more powerful visitor on ADT patterns that enables defining |
23 | * arbitrary function signatures with type-based dispatch on first argument. |
24 | */ |
25 | #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ |
26 | #define TVM_RELAY_PATTERN_FUNCTOR_H_ |
27 | |
28 | #include <tvm/node/functor.h> |
29 | #include <tvm/relay/error.h> |
30 | |
31 | #include <string> |
32 | #include <unordered_map> |
33 | #include <utility> |
34 | |
35 | #include "./adt.h" |
36 | #include "./expr.h" |
37 | #include "./op.h" |
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | |
/*!
 * \brief A dynamic functor on ADT patterns that dispatches on the runtime
 *  type of its first argument.
 *  You can use this as a more powerful visitor, since it allows you to
 *  define the types of further arguments to VisitPattern.
 *
 * \sa tvm/ir_functor.h
 *
 * \tparam FType function signature.
 *  Only the partial specialization for signatures of the form
 *  R(const Pattern&, Args...) is defined; this primary template is declared
 *  but never defined, so any other FType is a compile-time error.
 */
template <typename FType>
class PatternFunctor;
55 | |
// Default body for the overridable VisitPattern_ hooks: forward to
// VisitPatternDefault_. The expansion refers to a parameter named `op` and a
// parameter pack named `args` (of pack type `Args`), so those names must be in
// scope wherever the macro is used.
#define PATTERN_FUNCTOR_DEFAULT \
  { return VisitPatternDefault_(op, std::forward<Args>(args)...); }

// Register, in a local dispatch table named `vtable`, a handler for node type
// OP. The handler downcasts the ObjectRef to `const OP*` and calls the matching
// `self->VisitPattern_` overload. The expansion refers to `vtable`, `TSelf`
// and `Args`, which must be in scope at the point of use. Note that the
// expansion already ends in ';'.
#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP)                                                    \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {        \
    return self->VisitPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
  });
64 | |
65 | template <typename R, typename... Args> |
66 | class PatternFunctor<R(const Pattern& n, Args...)> { |
67 | private: |
68 | using TSelf = PatternFunctor<R(const Pattern& n, Args...)>; |
69 | using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; |
70 | |
71 | public: |
72 | /*! \brief the result type of this functor */ |
73 | using result_type = R; |
74 | /*! \brief virtual destructor */ |
75 | virtual ~PatternFunctor() {} |
76 | /*! |
77 | * \brief Same as call. |
78 | * \param n The expression node. |
79 | * \param args Additional arguments. |
80 | * \return The result of the call |
81 | */ |
82 | R operator()(const Pattern& n, Args... args) { |
83 | return VisitPattern(n, std::forward<Args>(args)...); |
84 | } |
85 | /*! |
86 | * \brief The functor call. |
87 | * \param n The expression node. |
88 | * \param args Additional arguments. |
89 | * \return The result of the call |
90 | */ |
91 | virtual R VisitPattern(const Pattern& n, Args... args) { |
92 | ICHECK(n.defined()); |
93 | static FType vtable = InitVTable(); |
94 | return vtable(n, this, std::forward<Args>(args)...); |
95 | } |
96 | // Functions that can be overriden by subclass |
97 | virtual R VisitPattern_(const PatternWildcardNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; |
98 | virtual R VisitPattern_(const PatternVarNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; |
99 | virtual R VisitPattern_(const PatternConstructorNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; |
100 | virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; |
101 | virtual R VisitPatternDefault_(const Object* op, Args...) { |
102 | LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); |
103 | throw; |
104 | } |
105 | |
106 | private: |
107 | // initialize the vtable. |
108 | static FType InitVTable() { |
109 | FType vtable; |
110 | // Set dispatch |
111 | RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode); |
112 | RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode); |
113 | RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode); |
114 | RELAY_PATTERN_FUNCTOR_DISPATCH(PatternTupleNode); |
115 | return vtable; |
116 | } |
117 | }; |
118 | |
119 | /*! \brief A simple visitor wrapper around PatternFunctor. |
120 | * |
121 | * Exposes two visitors with default traversal strategies, one |
122 | * which doesn't compute a result but can mutate internal state, |
123 | * and another which functionally builds a new pattern. |
124 | */ |
125 | class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n)> { |
126 | public: |
127 | void VisitPattern_(const PatternWildcardNode* op) override; |
128 | void VisitPattern_(const PatternVarNode* op) override; |
129 | void VisitPattern_(const PatternConstructorNode* op) override; |
130 | void VisitPattern_(const PatternTupleNode* op) override; |
131 | virtual void VisitType(const Type& t); |
132 | virtual void VisitVar(const Var& v); |
133 | virtual void VisitConstructor(const Constructor& c); |
134 | }; |
135 | |
136 | /*! \brief A wrapper around ExprFunctor which functionally updates the AST. |
137 | * |
138 | * ExprMutator uses memoization and self return in order to amortize |
139 | * the cost of using functional updates. |
140 | */ |
141 | class PatternMutator : public ::tvm::relay::PatternFunctor<Pattern(const Pattern&)> { |
142 | public: |
143 | Pattern Mutate(const Pattern& pat); |
144 | Pattern VisitPattern_(const PatternWildcardNode* op) override; |
145 | Pattern VisitPattern_(const PatternVarNode* op) override; |
146 | Pattern VisitPattern_(const PatternConstructorNode* op) override; |
147 | Pattern VisitPattern_(const PatternTupleNode* op) override; |
148 | /*! \brief Used to visit the types inside of patterns. |
149 | * |
150 | * Can be overloaded to transform the types in arbitrary |
151 | * ways, one way would be to define a sub-class of type |
152 | * visitor for types which transform them appropriately. |
153 | */ |
154 | virtual Type VisitType(const Type& t); |
155 | /*! \brief Used to visit the vars inside of patterns. */ |
156 | virtual Var VisitVar(const Var& v); |
157 | /*! \brief Used to visit the vars inside of patterns. */ |
158 | virtual Constructor VisitConstructor(const Constructor& c); |
159 | |
160 | private: |
161 | std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_map_; |
162 | }; |
163 | |
164 | } // namespace relay |
165 | } // namespace tvm |
166 | #endif // TVM_RELAY_PATTERN_FUNCTOR_H_ |
167 | |