1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/expr_functor.h
22 *
23 * \brief Functors for tir expressions.
24 */
25#ifndef TVM_TIR_EXPR_FUNCTOR_H_
26#define TVM_TIR_EXPR_FUNCTOR_H_
27
28#include <tvm/node/functor.h>
29#include <tvm/tir/expr.h>
30
31#include <utility>
32
33namespace tvm {
34namespace tir {
35
/*!
 * \brief A dynamic functor that dispatches on the runtime node type of its
 *  first (PrimExpr) argument.
 *
 * It can be used as a more powerful visitor, because the signature of the
 * visit functions (return type and any extra arguments) is chosen by the user
 * through the FType template argument.
 *
 * Because each visit function returns its result directly, there is no need to
 * book-keep results in visitor state, which easily causes bugs when that state
 * is maintained incorrectly.
 *
 * \code
 *  // A functor that evaluates an expression, substituting b for every variable.
 *  class MyExprFunctor
 *    : public tir::ExprFunctor<int(const PrimExpr&, int)> {
 *   public:
 *    int VisitExpr_(const VarNode* op, int b) final {
 *      return b;
 *    }
 *    int VisitExpr_(const IntImmNode* op, int b) final {
 *      return op->value;
 *    }
 *    int VisitExpr_(const AddNode* op, int b) final {
 *      return VisitExpr(op->a, b) + VisitExpr(op->b, b);
 *    }
 *  };
 *  MyExprFunctor f;
 *  Var x("x");
 *  ICHECK_EQ(f(x + 1, 2), 3);
 * \endcode
 *
 * \note Why do we need this more powerful functor:
 *
 *  We often need to implement transformer-style tasks, for example taking a
 *  PrimExpr and turning it into some analysis result. Doing this with a plain
 *  visitor that stores intermediate results in member state is easy to get
 *  wrong.
 *
 * \tparam FType The function signature.
 *  This template is only defined for signatures of the form
 *  R(const PrimExpr&, Args...); any other FType is left as an incomplete type.
 */
template <typename FType>
class ExprFunctor;
76
77// functions to be overriden.
78#define EXPR_FUNCTOR_DEFAULT \
79 { return VisitExprDefault_(op, std::forward<Args>(args)...); }
80
81#define IR_EXPR_FUNCTOR_DISPATCH(OP) \
82 vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
83 return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
84 });
85
86template <typename R, typename... Args>
87class ExprFunctor<R(const PrimExpr& n, Args...)> {
88 private:
89 using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>;
90 using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
91
92 public:
93 /*! \brief the result type of this functor */
94 using result_type = R;
95 /*! \brief virtual destructor */
96 virtual ~ExprFunctor() {}
97 /*!
98 * \brief Same as call.
99 * \param n The expression node.
100 * \param args Additional arguments.
101 * \return The result of the call
102 */
103 R operator()(const PrimExpr& n, Args... args) {
104 return VisitExpr(n, std::forward<Args>(args)...);
105 }
106 /*!
107 * \brief The functor call.
108 * \param n The expression node.
109 * \param args Additional arguments.
110 * \return The result of the call
111 */
112 virtual R VisitExpr(const PrimExpr& n, Args... args) {
113 static FType vtable = InitVTable();
114 return vtable(n, this, std::forward<Args>(args)...);
115 }
116 // Functions that can be overriden by subclass
117 virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
118 virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
119 return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
120 }
121 virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
122 virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
123 virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
124 virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
125 virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
126 virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
127 virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
128 virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
129 virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
130 virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
131 virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
132 virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
133 virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
134 virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
135 virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
136 virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
137 virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
138 virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
139 virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
140 virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
141 virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
142 virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
143 virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
144 virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
145 virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
146 virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
147 virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
148 virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
149 virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
150 virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
151 virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
152 virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
153 virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
154 virtual R VisitExprDefault_(const Object* op, Args...) {
155 LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
156 }
157
158 private:
159 // initialize the vtable.
160 static FType InitVTable() {
161 FType vtable;
162 // Set dispatch
163 IR_EXPR_FUNCTOR_DISPATCH(VarNode);
164 IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
165 IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
166 IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
167 IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
168 IR_EXPR_FUNCTOR_DISPATCH(LetNode);
169 IR_EXPR_FUNCTOR_DISPATCH(CallNode);
170 IR_EXPR_FUNCTOR_DISPATCH(AddNode);
171 IR_EXPR_FUNCTOR_DISPATCH(SubNode);
172 IR_EXPR_FUNCTOR_DISPATCH(MulNode);
173 IR_EXPR_FUNCTOR_DISPATCH(DivNode);
174 IR_EXPR_FUNCTOR_DISPATCH(ModNode);
175 IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode);
176 IR_EXPR_FUNCTOR_DISPATCH(FloorModNode);
177 IR_EXPR_FUNCTOR_DISPATCH(MinNode);
178 IR_EXPR_FUNCTOR_DISPATCH(MaxNode);
179 IR_EXPR_FUNCTOR_DISPATCH(EQNode);
180 IR_EXPR_FUNCTOR_DISPATCH(NENode);
181 IR_EXPR_FUNCTOR_DISPATCH(LTNode);
182 IR_EXPR_FUNCTOR_DISPATCH(LENode);
183 IR_EXPR_FUNCTOR_DISPATCH(GTNode);
184 IR_EXPR_FUNCTOR_DISPATCH(GENode);
185 IR_EXPR_FUNCTOR_DISPATCH(AndNode);
186 IR_EXPR_FUNCTOR_DISPATCH(OrNode);
187 IR_EXPR_FUNCTOR_DISPATCH(ReduceNode);
188 IR_EXPR_FUNCTOR_DISPATCH(CastNode);
189 IR_EXPR_FUNCTOR_DISPATCH(NotNode);
190 IR_EXPR_FUNCTOR_DISPATCH(SelectNode);
191 IR_EXPR_FUNCTOR_DISPATCH(RampNode);
192 IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode);
193 IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode);
194 IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
195 IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
196 IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
197 IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
198 return vtable;
199 }
200};
201
202#undef IR_EXPR_FUNCTOR_DISPATCH
203#undef EXPR_FUNCTOR_DEFAULT
204
205/*!
206 * \brief ExprVisitor
207 */
208class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
209 public:
210 using ExprFunctor::operator();
211
212 protected:
213 using ExprFunctor::VisitExpr;
214 // list of functions to override.
215 void VisitExpr_(const VarNode* op) override;
216 void VisitExpr_(const SizeVarNode* op) override;
217 void VisitExpr_(const LoadNode* op) override;
218 void VisitExpr_(const BufferLoadNode* op) override;
219 void VisitExpr_(const ProducerLoadNode* op) override;
220 void VisitExpr_(const LetNode* op) override;
221 void VisitExpr_(const CallNode* op) override;
222 void VisitExpr_(const AddNode* op) override;
223 void VisitExpr_(const SubNode* op) override;
224 void VisitExpr_(const MulNode* op) override;
225 void VisitExpr_(const DivNode* op) override;
226 void VisitExpr_(const ModNode* op) override;
227 void VisitExpr_(const FloorDivNode* op) override;
228 void VisitExpr_(const FloorModNode* op) override;
229 void VisitExpr_(const MinNode* op) override;
230 void VisitExpr_(const MaxNode* op) override;
231 void VisitExpr_(const EQNode* op) override;
232 void VisitExpr_(const NENode* op) override;
233 void VisitExpr_(const LTNode* op) override;
234 void VisitExpr_(const LENode* op) override;
235 void VisitExpr_(const GTNode* op) override;
236 void VisitExpr_(const GENode* op) override;
237 void VisitExpr_(const AndNode* op) override;
238 void VisitExpr_(const OrNode* op) override;
239 void VisitExpr_(const ReduceNode* op) override;
240 void VisitExpr_(const CastNode* op) override;
241 void VisitExpr_(const NotNode* op) override;
242 void VisitExpr_(const SelectNode* op) override;
243 void VisitExpr_(const RampNode* op) override;
244 void VisitExpr_(const BroadcastNode* op) override;
245 void VisitExpr_(const ShuffleNode* op) override;
246 void VisitExpr_(const IntImmNode* op) override;
247 void VisitExpr_(const FloatImmNode* op) override;
248 void VisitExpr_(const StringImmNode* op) override;
249 void VisitExpr_(const AnyNode* op) override;
250};
251
252/*!
253 * \brief ExprMutator that mutates expressions.
254 */
255class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
256 public:
257 using ExprFunctor::operator();
258
259 protected:
260 using ExprFunctor::VisitExpr;
261 // list of functions to override.
262 PrimExpr VisitExpr_(const VarNode* op) override;
263 PrimExpr VisitExpr_(const SizeVarNode* op) override;
264 PrimExpr VisitExpr_(const LoadNode* op) override;
265 PrimExpr VisitExpr_(const BufferLoadNode* op) override;
266 PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
267 PrimExpr VisitExpr_(const LetNode* op) override;
268 PrimExpr VisitExpr_(const CallNode* op) override;
269 PrimExpr VisitExpr_(const AddNode* op) override;
270 PrimExpr VisitExpr_(const SubNode* op) override;
271 PrimExpr VisitExpr_(const MulNode* op) override;
272 PrimExpr VisitExpr_(const DivNode* op) override;
273 PrimExpr VisitExpr_(const ModNode* op) override;
274 PrimExpr VisitExpr_(const FloorDivNode* op) override;
275 PrimExpr VisitExpr_(const FloorModNode* op) override;
276 PrimExpr VisitExpr_(const MinNode* op) override;
277 PrimExpr VisitExpr_(const MaxNode* op) override;
278 PrimExpr VisitExpr_(const EQNode* op) override;
279 PrimExpr VisitExpr_(const NENode* op) override;
280 PrimExpr VisitExpr_(const LTNode* op) override;
281 PrimExpr VisitExpr_(const LENode* op) override;
282 PrimExpr VisitExpr_(const GTNode* op) override;
283 PrimExpr VisitExpr_(const GENode* op) override;
284 PrimExpr VisitExpr_(const AndNode* op) override;
285 PrimExpr VisitExpr_(const OrNode* op) override;
286 PrimExpr VisitExpr_(const ReduceNode* op) override;
287 PrimExpr VisitExpr_(const CastNode* op) override;
288 PrimExpr VisitExpr_(const NotNode* op) override;
289 PrimExpr VisitExpr_(const SelectNode* op) override;
290 PrimExpr VisitExpr_(const RampNode* op) override;
291 PrimExpr VisitExpr_(const BroadcastNode* op) override;
292 PrimExpr VisitExpr_(const ShuffleNode* op) override;
293 PrimExpr VisitExpr_(const IntImmNode* op) override;
294 PrimExpr VisitExpr_(const FloatImmNode* op) override;
295 PrimExpr VisitExpr_(const StringImmNode* op) override;
296 PrimExpr VisitExpr_(const AnyNode* op) override;
297};
298
299} // namespace tir
300} // namespace tvm
301#endif // TVM_TIR_EXPR_FUNCTOR_H_
302