1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/expr_functor.h |
22 | * |
23 | * \brief Functors for tir expressions. |
24 | */ |
25 | #ifndef TVM_TIR_EXPR_FUNCTOR_H_ |
26 | #define TVM_TIR_EXPR_FUNCTOR_H_ |
27 | |
28 | #include <tvm/node/functor.h> |
29 | #include <tvm/tir/expr.h> |
30 | |
31 | #include <utility> |
32 | |
33 | namespace tvm { |
34 | namespace tir { |
35 | |
/*!
 * \brief A dynamical functor that dispatches on the first Expr argument.
 *  You can use this as a more powerful Visitor, since it allows you to
 *  define the function signatures of the Visit functions.
 *
 *  This helps you avoid book-keeping the return value of a Visitor via state,
 *  which easily causes bugs when that state is incorrectly maintained.
 *
 * \code
 *  // A functor that maps every variable to b and computes the result.
 *  class MyExprFunctor
 *    : public tir::ExprFunctor<int(const PrimExpr&, int)> {
 *   public:
 *    int VisitExpr_(const VarNode* op, int b) final {
 *      return b;
 *    }
 *    int VisitExpr_(const IntImmNode* op, int b) final {
 *      return op->value;
 *    }
 *    int VisitExpr_(const AddNode* op, int b) final {
 *      return VisitExpr(op->a, b) + VisitExpr(op->b, b);
 *    }
 *  };
 *  MyExprFunctor f;
 *  Var x("x");
 *  ICHECK_EQ(f(x + 1, 2), 3);
 * \endcode
 *
 * \note Why do we need this more powerful Functor:
 *
 *  We often need to implement transformer-like tasks.
 *  Say we want to take an Expr and transform it into some analysis result.
 *  This is easy to get wrong with a plain Visitor, which has to pass results
 *  back through member state (see IRVisitor's documentation for possible
 *  error cases).
 *
 * \tparam FType The function signature.
 *  This type is only defined for FType with the function signature
 *  R(const PrimExpr&, Args...); see the partial specialization below.
 */
template <typename FType>
class ExprFunctor;
76 | |
// Default body for the VisitExpr_ overloads that subclasses may override:
// hands the node and any extra arguments to VisitExprDefault_.
// Only meaningful inside ExprFunctor, where `op`, `Args` and `args` are in scope.
#define EXPR_FUNCTOR_DEFAULT \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }

// Calls `vtable.set_dispatch<OP>` with a lambda that downcasts the ObjectRef to
// `const OP*` and forwards it to the matching `self->VisitExpr_` overload.
// The expansion intentionally has no trailing semicolon: every use site writes
// `IR_EXPR_FUNCTOR_DISPATCH(X);`, so the macro reads like an ordinary statement
// instead of expanding to `...});;` (a stray empty statement).
#define IR_EXPR_FUNCTOR_DISPATCH(OP)                                                        \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {      \
    return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
  })
85 | |
86 | template <typename R, typename... Args> |
87 | class ExprFunctor<R(const PrimExpr& n, Args...)> { |
88 | private: |
89 | using TSelf = ExprFunctor<R(const PrimExpr& n, Args...)>; |
90 | using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>; |
91 | |
92 | public: |
93 | /*! \brief the result type of this functor */ |
94 | using result_type = R; |
95 | /*! \brief virtual destructor */ |
96 | virtual ~ExprFunctor() {} |
97 | /*! |
98 | * \brief Same as call. |
99 | * \param n The expression node. |
100 | * \param args Additional arguments. |
101 | * \return The result of the call |
102 | */ |
103 | R operator()(const PrimExpr& n, Args... args) { |
104 | return VisitExpr(n, std::forward<Args>(args)...); |
105 | } |
106 | /*! |
107 | * \brief The functor call. |
108 | * \param n The expression node. |
109 | * \param args Additional arguments. |
110 | * \return The result of the call |
111 | */ |
112 | virtual R VisitExpr(const PrimExpr& n, Args... args) { |
113 | static FType vtable = InitVTable(); |
114 | return vtable(n, this, std::forward<Args>(args)...); |
115 | } |
116 | // Functions that can be overriden by subclass |
117 | virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
118 | virtual R VisitExpr_(const SizeVarNode* op, Args... args) { |
119 | return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...); |
120 | } |
121 | virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
122 | virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
123 | virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
124 | virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
125 | virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
126 | virtual R VisitExpr_(const AddNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
127 | virtual R VisitExpr_(const SubNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
128 | virtual R VisitExpr_(const MulNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
129 | virtual R VisitExpr_(const DivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
130 | virtual R VisitExpr_(const ModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
131 | virtual R VisitExpr_(const FloorDivNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
132 | virtual R VisitExpr_(const FloorModNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
133 | virtual R VisitExpr_(const MinNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
134 | virtual R VisitExpr_(const MaxNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
135 | virtual R VisitExpr_(const EQNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
136 | virtual R VisitExpr_(const NENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
137 | virtual R VisitExpr_(const LTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
138 | virtual R VisitExpr_(const LENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
139 | virtual R VisitExpr_(const GTNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
140 | virtual R VisitExpr_(const GENode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
141 | virtual R VisitExpr_(const AndNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
142 | virtual R VisitExpr_(const OrNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
143 | virtual R VisitExpr_(const ReduceNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
144 | virtual R VisitExpr_(const CastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
145 | virtual R VisitExpr_(const NotNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
146 | virtual R VisitExpr_(const SelectNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
147 | virtual R VisitExpr_(const RampNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
148 | virtual R VisitExpr_(const BroadcastNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
149 | virtual R VisitExpr_(const ShuffleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
150 | virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
151 | virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
152 | virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
153 | virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; |
154 | virtual R VisitExprDefault_(const Object* op, Args...) { |
155 | LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); |
156 | } |
157 | |
158 | private: |
159 | // initialize the vtable. |
160 | static FType InitVTable() { |
161 | FType vtable; |
162 | // Set dispatch |
163 | IR_EXPR_FUNCTOR_DISPATCH(VarNode); |
164 | IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode); |
165 | IR_EXPR_FUNCTOR_DISPATCH(LoadNode); |
166 | IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode); |
167 | IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode); |
168 | IR_EXPR_FUNCTOR_DISPATCH(LetNode); |
169 | IR_EXPR_FUNCTOR_DISPATCH(CallNode); |
170 | IR_EXPR_FUNCTOR_DISPATCH(AddNode); |
171 | IR_EXPR_FUNCTOR_DISPATCH(SubNode); |
172 | IR_EXPR_FUNCTOR_DISPATCH(MulNode); |
173 | IR_EXPR_FUNCTOR_DISPATCH(DivNode); |
174 | IR_EXPR_FUNCTOR_DISPATCH(ModNode); |
175 | IR_EXPR_FUNCTOR_DISPATCH(FloorDivNode); |
176 | IR_EXPR_FUNCTOR_DISPATCH(FloorModNode); |
177 | IR_EXPR_FUNCTOR_DISPATCH(MinNode); |
178 | IR_EXPR_FUNCTOR_DISPATCH(MaxNode); |
179 | IR_EXPR_FUNCTOR_DISPATCH(EQNode); |
180 | IR_EXPR_FUNCTOR_DISPATCH(NENode); |
181 | IR_EXPR_FUNCTOR_DISPATCH(LTNode); |
182 | IR_EXPR_FUNCTOR_DISPATCH(LENode); |
183 | IR_EXPR_FUNCTOR_DISPATCH(GTNode); |
184 | IR_EXPR_FUNCTOR_DISPATCH(GENode); |
185 | IR_EXPR_FUNCTOR_DISPATCH(AndNode); |
186 | IR_EXPR_FUNCTOR_DISPATCH(OrNode); |
187 | IR_EXPR_FUNCTOR_DISPATCH(ReduceNode); |
188 | IR_EXPR_FUNCTOR_DISPATCH(CastNode); |
189 | IR_EXPR_FUNCTOR_DISPATCH(NotNode); |
190 | IR_EXPR_FUNCTOR_DISPATCH(SelectNode); |
191 | IR_EXPR_FUNCTOR_DISPATCH(RampNode); |
192 | IR_EXPR_FUNCTOR_DISPATCH(ShuffleNode); |
193 | IR_EXPR_FUNCTOR_DISPATCH(BroadcastNode); |
194 | IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); |
195 | IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); |
196 | IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); |
197 | IR_EXPR_FUNCTOR_DISPATCH(AnyNode); |
198 | return vtable; |
199 | } |
200 | }; |
201 | |
// The helper macros are only needed while defining ExprFunctor above;
// undefine them so they do not leak into files that include this header.
#undef IR_EXPR_FUNCTOR_DISPATCH
#undef EXPR_FUNCTOR_DEFAULT
204 | |
205 | /*! |
206 | * \brief ExprVisitor |
207 | */ |
208 | class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> { |
209 | public: |
210 | using ExprFunctor::operator(); |
211 | |
212 | protected: |
213 | using ExprFunctor::VisitExpr; |
214 | // list of functions to override. |
215 | void VisitExpr_(const VarNode* op) override; |
216 | void VisitExpr_(const SizeVarNode* op) override; |
217 | void VisitExpr_(const LoadNode* op) override; |
218 | void VisitExpr_(const BufferLoadNode* op) override; |
219 | void VisitExpr_(const ProducerLoadNode* op) override; |
220 | void VisitExpr_(const LetNode* op) override; |
221 | void VisitExpr_(const CallNode* op) override; |
222 | void VisitExpr_(const AddNode* op) override; |
223 | void VisitExpr_(const SubNode* op) override; |
224 | void VisitExpr_(const MulNode* op) override; |
225 | void VisitExpr_(const DivNode* op) override; |
226 | void VisitExpr_(const ModNode* op) override; |
227 | void VisitExpr_(const FloorDivNode* op) override; |
228 | void VisitExpr_(const FloorModNode* op) override; |
229 | void VisitExpr_(const MinNode* op) override; |
230 | void VisitExpr_(const MaxNode* op) override; |
231 | void VisitExpr_(const EQNode* op) override; |
232 | void VisitExpr_(const NENode* op) override; |
233 | void VisitExpr_(const LTNode* op) override; |
234 | void VisitExpr_(const LENode* op) override; |
235 | void VisitExpr_(const GTNode* op) override; |
236 | void VisitExpr_(const GENode* op) override; |
237 | void VisitExpr_(const AndNode* op) override; |
238 | void VisitExpr_(const OrNode* op) override; |
239 | void VisitExpr_(const ReduceNode* op) override; |
240 | void VisitExpr_(const CastNode* op) override; |
241 | void VisitExpr_(const NotNode* op) override; |
242 | void VisitExpr_(const SelectNode* op) override; |
243 | void VisitExpr_(const RampNode* op) override; |
244 | void VisitExpr_(const BroadcastNode* op) override; |
245 | void VisitExpr_(const ShuffleNode* op) override; |
246 | void VisitExpr_(const IntImmNode* op) override; |
247 | void VisitExpr_(const FloatImmNode* op) override; |
248 | void VisitExpr_(const StringImmNode* op) override; |
249 | void VisitExpr_(const AnyNode* op) override; |
250 | }; |
251 | |
252 | /*! |
253 | * \brief ExprMutator that mutates expressions. |
254 | */ |
255 | class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> { |
256 | public: |
257 | using ExprFunctor::operator(); |
258 | |
259 | protected: |
260 | using ExprFunctor::VisitExpr; |
261 | // list of functions to override. |
262 | PrimExpr VisitExpr_(const VarNode* op) override; |
263 | PrimExpr VisitExpr_(const SizeVarNode* op) override; |
264 | PrimExpr VisitExpr_(const LoadNode* op) override; |
265 | PrimExpr VisitExpr_(const BufferLoadNode* op) override; |
266 | PrimExpr VisitExpr_(const ProducerLoadNode* op) override; |
267 | PrimExpr VisitExpr_(const LetNode* op) override; |
268 | PrimExpr VisitExpr_(const CallNode* op) override; |
269 | PrimExpr VisitExpr_(const AddNode* op) override; |
270 | PrimExpr VisitExpr_(const SubNode* op) override; |
271 | PrimExpr VisitExpr_(const MulNode* op) override; |
272 | PrimExpr VisitExpr_(const DivNode* op) override; |
273 | PrimExpr VisitExpr_(const ModNode* op) override; |
274 | PrimExpr VisitExpr_(const FloorDivNode* op) override; |
275 | PrimExpr VisitExpr_(const FloorModNode* op) override; |
276 | PrimExpr VisitExpr_(const MinNode* op) override; |
277 | PrimExpr VisitExpr_(const MaxNode* op) override; |
278 | PrimExpr VisitExpr_(const EQNode* op) override; |
279 | PrimExpr VisitExpr_(const NENode* op) override; |
280 | PrimExpr VisitExpr_(const LTNode* op) override; |
281 | PrimExpr VisitExpr_(const LENode* op) override; |
282 | PrimExpr VisitExpr_(const GTNode* op) override; |
283 | PrimExpr VisitExpr_(const GENode* op) override; |
284 | PrimExpr VisitExpr_(const AndNode* op) override; |
285 | PrimExpr VisitExpr_(const OrNode* op) override; |
286 | PrimExpr VisitExpr_(const ReduceNode* op) override; |
287 | PrimExpr VisitExpr_(const CastNode* op) override; |
288 | PrimExpr VisitExpr_(const NotNode* op) override; |
289 | PrimExpr VisitExpr_(const SelectNode* op) override; |
290 | PrimExpr VisitExpr_(const RampNode* op) override; |
291 | PrimExpr VisitExpr_(const BroadcastNode* op) override; |
292 | PrimExpr VisitExpr_(const ShuffleNode* op) override; |
293 | PrimExpr VisitExpr_(const IntImmNode* op) override; |
294 | PrimExpr VisitExpr_(const FloatImmNode* op) override; |
295 | PrimExpr VisitExpr_(const StringImmNode* op) override; |
296 | PrimExpr VisitExpr_(const AnyNode* op) override; |
297 | }; |
298 | |
299 | } // namespace tir |
300 | } // namespace tvm |
301 | #endif // TVM_TIR_EXPR_FUNCTOR_H_ |
302 | |