1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/relay/dataflow_pattern_functor.h
22 * \brief A set of passes for operating on pattern graphs.
23 */
24#ifndef TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_
25#define TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_
26
27#include <tvm/relay/dataflow_pattern.h>
28
29#include <unordered_set>
30#include <utility>
31
32namespace tvm {
33namespace relay {
34
35/*!
 * \brief A dynamical functor that dispatches on the first DFPattern argument.
37 *
38 * \tparam FType function signature
39 * This type is only defined for FType with function signature R(const DFPattern&,
40 * Args...)
41 */
42template <typename FType>
43class DFPatternFunctor;
44
/*!
 * \brief Default body for the per-node VisitDFPattern_ overloads (the functions
 *  meant to be overridden).
 *
 * Expands to a complete function body that passes the node pointer `op` and the
 * extra arguments `args` on to VisitDFPatternDefault_. The names `op`, `args`
 * and the parameter pack `Args` must be in scope where the macro is expanded.
 */
#define DFPATTERN_FUNCTOR_DEFAULT                                   \
  {                                                                 \
    return VisitDFPatternDefault_(op, std::forward<Args>(args)...); \
  }
48
/*!
 * \brief Register the dispatch entry for one pattern node type.
 *
 * Installs into the local `vtable` a callback that downcasts the incoming
 * ObjectRef to `const OP*` and invokes the matching VisitDFPattern_ overload on
 * the functor it is given, passing the extra arguments along. `vtable`, `TSelf`
 * and `Args` must be in scope at the expansion site. The expansion already ends
 * with a semicolon.
 */
#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                     \
  vtable.template set_dispatch<OP>([](const ObjectRef& node_ref, TSelf* functor, Args... rest) { \
    const OP* typed_node = static_cast<const OP*>(node_ref.get());                               \
    return functor->VisitDFPattern_(typed_node, std::forward<Args>(rest)...);                    \
  });
53
54template <typename R, typename... Args>
55class DFPatternFunctor<R(const DFPattern& n, Args...)> {
56 private:
57 using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;
58 using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
59
60 public:
61 /*! \brief virtual destructor */
62 virtual ~DFPatternFunctor() {}
63 /*!
64 * \brief Same as call.
65 * \param n The expression node.
66 * \param args Additional arguments.
67 * \return The result of the call
68 */
69 R operator()(const DFPattern& n, Args... args) {
70 return VisitDFPattern(n, std::forward<Args>(args)...);
71 }
72 /*!
73 * \brief The functor call.
74 * \param n The expression node.
75 * \param args Additional arguments.
76 * \return The result of the call
77 */
78 virtual R VisitDFPattern(const DFPattern& n, Args... args) {
79 ICHECK(n.defined());
80 static FType vtable = InitVTable();
81 return vtable(n, this, std::forward<Args>(args)...);
82 }
83 // Functions that can be overriden by subclass
84 virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
85 virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
86 virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
87 virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
88 virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
89 virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
90 virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
91 virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
92 virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
93 virtual R VisitDFPattern_(const LetPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
94 virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
95 virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
96 Args... args) DFPATTERN_FUNCTOR_DEFAULT;
97 virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
98 virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
99 virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
100 virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
101 virtual R VisitDFPatternDefault_(const Object* op, Args...) {
102 LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
103 throw;
104 }
105
106 private:
107 // initialize the vtable.
108 static FType InitVTable() {
109 FType vtable;
110 // Set dispatch
111 RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);
112 RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);
113 RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);
114 RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
115 RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
116 RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
117 RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
118 RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
119 RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
120 RELAY_DFPATTERN_FUNCTOR_DISPATCH(LetPatternNode);
121 RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
122 RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
123 RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
124 RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
125 RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
126 RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
127 return vtable;
128 }
129};
130
131/*!
132 * \brief A simple visitor wrapper around DFPatternFunctor.
133 * Recursively visit the content.
134 *
135 * DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once.
136 */
137class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
138 public:
139 void VisitDFPattern(const DFPattern& pattern) override;
140 void VisitDFPattern_(const AltPatternNode* op) override;
141 void VisitDFPattern_(const AttrPatternNode* op) override;
142 void VisitDFPattern_(const CallPatternNode* op) override;
143 void VisitDFPattern_(const ConstantPatternNode* op) override;
144 void VisitDFPattern_(const DataTypePatternNode* op) override;
145 void VisitDFPattern_(const DominatorPatternNode* op) override;
146 void VisitDFPattern_(const ExprPatternNode* op) override;
147 void VisitDFPattern_(const FunctionPatternNode* op) override;
148 void VisitDFPattern_(const IfPatternNode* op) override;
149 void VisitDFPattern_(const LetPatternNode* op) override;
150 void VisitDFPattern_(const ShapePatternNode* op) override;
151 void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
152 void VisitDFPattern_(const TuplePatternNode* op) override;
153 void VisitDFPattern_(const TypePatternNode* op) override;
154 void VisitDFPattern_(const VarPatternNode* op) override;
155 void VisitDFPattern_(const WildcardPatternNode* op) override;
156
157 protected:
158 // set of already-visited nodes
159 std::unordered_set<const Object*> visited_;
160};
161
162} // namespace relay
163} // namespace tvm
164#endif // TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_
165