1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/pattern_functor.h
22 * \brief A more powerful visitor on ADT patterns that enables defining
23 * arbitrary function signatures with type-based dispatch on first argument.
24 */
25#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
26#define TVM_RELAY_PATTERN_FUNCTOR_H_
27
28#include <tvm/node/functor.h>
29#include <tvm/relay/error.h>
30
31#include <string>
32#include <unordered_map>
33#include <utility>
34
35#include "./adt.h"
36#include "./expr.h"
37#include "./op.h"
38
39namespace tvm {
40namespace relay {
41
/*!
 * \brief A dynamic functor on ADT patterns that dispatches on its first argument.
 *  You can use this as a more powerful visitor, since it allows you to
 *  define the types of further arguments to VisitPattern.
 *
 * \sa tvm/ir_functor.h
 *
 * \tparam FType function signature
 *  This type is only defined for FType with function signature R(const Pattern&,
 *  Args...)
 *
 * \note The primary template is declared but never defined; the usable
 *  definition is the partial specialization for the signature form above.
 */
template <typename FType>
class PatternFunctor;
55
// Default body for the VisitPattern_ overloads that subclasses may override.
// It forwards the node pointer and the extra arguments to VisitPatternDefault_.
// The expansion site must have `op`, the parameter pack `args` and the type
// pack `Args` in scope, and VisitPatternDefault_ must be callable there.
#define PATTERN_FUNCTOR_DEFAULT \
 { return VisitPatternDefault_(op, std::forward<Args>(args)...); }
59
// Registers a handler for node type OP in the dispatch table named `vtable`.
// The handler casts the incoming object to `const OP*` and calls the matching
// VisitPattern_ overload on the functor, passing the extra arguments along.
// The expansion site must have `vtable`, `TSelf` and the type pack `Args` in scope.
#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \
  vtable.template set_dispatch<OP>([](const ObjectRef& ref, TSelf* functor, Args... args) { \
    const OP* node = static_cast<const OP*>(ref.get()); \
    return functor->VisitPattern_(node, std::forward<Args>(args)...); \
  });
64
65template <typename R, typename... Args>
66class PatternFunctor<R(const Pattern& n, Args...)> {
67 private:
68 using TSelf = PatternFunctor<R(const Pattern& n, Args...)>;
69 using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
70
71 public:
72 /*! \brief the result type of this functor */
73 using result_type = R;
74 /*! \brief virtual destructor */
75 virtual ~PatternFunctor() {}
76 /*!
77 * \brief Same as call.
78 * \param n The expression node.
79 * \param args Additional arguments.
80 * \return The result of the call
81 */
82 R operator()(const Pattern& n, Args... args) {
83 return VisitPattern(n, std::forward<Args>(args)...);
84 }
85 /*!
86 * \brief The functor call.
87 * \param n The expression node.
88 * \param args Additional arguments.
89 * \return The result of the call
90 */
91 virtual R VisitPattern(const Pattern& n, Args... args) {
92 ICHECK(n.defined());
93 static FType vtable = InitVTable();
94 return vtable(n, this, std::forward<Args>(args)...);
95 }
96 // Functions that can be overriden by subclass
97 virtual R VisitPattern_(const PatternWildcardNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
98 virtual R VisitPattern_(const PatternVarNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
99 virtual R VisitPattern_(const PatternConstructorNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
100 virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT;
101 virtual R VisitPatternDefault_(const Object* op, Args...) {
102 LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
103 throw;
104 }
105
106 private:
107 // initialize the vtable.
108 static FType InitVTable() {
109 FType vtable;
110 // Set dispatch
111 RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
112 RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
113 RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
114 RELAY_PATTERN_FUNCTOR_DISPATCH(PatternTupleNode);
115 return vtable;
116 }
117};
118
119/*! \brief A simple visitor wrapper around PatternFunctor.
120 *
121 * Exposes two visitors with default traversal strategies, one
122 * which doesn't compute a result but can mutate internal state,
123 * and another which functionally builds a new pattern.
124 */
125class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n)> {
126 public:
127 void VisitPattern_(const PatternWildcardNode* op) override;
128 void VisitPattern_(const PatternVarNode* op) override;
129 void VisitPattern_(const PatternConstructorNode* op) override;
130 void VisitPattern_(const PatternTupleNode* op) override;
131 virtual void VisitType(const Type& t);
132 virtual void VisitVar(const Var& v);
133 virtual void VisitConstructor(const Constructor& c);
134};
135
136/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
137 *
138 * ExprMutator uses memoization and self return in order to amortize
139 * the cost of using functional updates.
140 */
141class PatternMutator : public ::tvm::relay::PatternFunctor<Pattern(const Pattern&)> {
142 public:
143 Pattern Mutate(const Pattern& pat);
144 Pattern VisitPattern_(const PatternWildcardNode* op) override;
145 Pattern VisitPattern_(const PatternVarNode* op) override;
146 Pattern VisitPattern_(const PatternConstructorNode* op) override;
147 Pattern VisitPattern_(const PatternTupleNode* op) override;
148 /*! \brief Used to visit the types inside of patterns.
149 *
150 * Can be overloaded to transform the types in arbitrary
151 * ways, one way would be to define a sub-class of type
152 * visitor for types which transform them appropriately.
153 */
154 virtual Type VisitType(const Type& t);
155 /*! \brief Used to visit the vars inside of patterns. */
156 virtual Var VisitVar(const Var& v);
157 /*! \brief Used to visit the vars inside of patterns. */
158 virtual Constructor VisitConstructor(const Constructor& c);
159
160 private:
161 std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_map_;
162};
163
164} // namespace relay
165} // namespace tvm
166#endif // TVM_RELAY_PATTERN_FUNCTOR_H_
167