1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/stmt_functor.h |
22 | * |
 * \brief Functors for tir stmts, and
 *        utility functions to call common functors.
25 | */ |
26 | #ifndef TVM_TIR_STMT_FUNCTOR_H_ |
27 | #define TVM_TIR_STMT_FUNCTOR_H_ |
28 | |
29 | #include <tvm/node/functor.h> |
30 | #include <tvm/tir/expr.h> |
31 | #include <tvm/tir/expr_functor.h> |
32 | #include <tvm/tir/function.h> |
33 | #include <tvm/tir/stmt.h> |
34 | |
35 | #include <unordered_map> |
36 | #include <utility> |
37 | |
38 | namespace tvm { |
39 | namespace tir { |
40 | /*! |
41 | * \brief Same as ExprFunctor except it is applied on statements |
42 | * \tparam FType The function signature. |
43 | * \sa ExprFunctor |
44 | */ |
45 | template <typename FType> |
46 | class StmtFunctor; |
47 | |
48 | #define STMT_FUNCTOR_DEFAULT \ |
49 | { return VisitStmtDefault_(op, std::forward<Args>(args)...); } |
50 | |
51 | #define IR_STMT_FUNCTOR_DISPATCH(OP) \ |
52 | vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \ |
53 | return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \ |
54 | }); |
55 | |
56 | template <typename R, typename... Args> |
57 | class StmtFunctor<R(const Stmt& n, Args... args)> { |
58 | private: |
59 | using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>; |
60 | using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>; |
61 | |
62 | public: |
63 | /*! \brief the result type of this functor */ |
64 | using result_type = R; |
65 | /*! \brief virtual destructor */ |
66 | virtual ~StmtFunctor() {} |
67 | /*! |
68 | * \brief Same as call. |
69 | * \param n The stmt node. |
70 | * \param args Additional arguments. |
71 | * \return The result of the call |
72 | */ |
73 | R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward<Args>(args)...); } |
74 | /*! |
75 | * \brief The functor call. |
76 | * \param n The stmt node. |
77 | * \param args Additional arguments. |
78 | * \return The result of the call |
79 | */ |
80 | virtual R VisitStmt(const Stmt& n, Args... args) { |
81 | static FType vtable = InitVTable(); |
82 | return vtable(n, this, std::forward<Args>(args)...); |
83 | } |
84 | // Functions that can be overriden by subclass |
85 | virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
86 | virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
87 | virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
88 | virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
89 | virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
90 | virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
91 | virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
92 | virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
93 | virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
94 | virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
95 | virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
96 | virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
97 | virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
98 | virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
99 | virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
100 | virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
101 | virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
102 | virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
103 | virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; |
104 | virtual R VisitStmtDefault_(const Object* op, Args...) { |
105 | LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); |
106 | } |
107 | |
108 | private: |
109 | // initialize the vtable. |
110 | static FType InitVTable() { |
111 | FType vtable; |
112 | IR_STMT_FUNCTOR_DISPATCH(LetStmtNode); |
113 | IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode); |
114 | IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode); |
115 | IR_STMT_FUNCTOR_DISPATCH(ForNode); |
116 | IR_STMT_FUNCTOR_DISPATCH(WhileNode); |
117 | IR_STMT_FUNCTOR_DISPATCH(AllocateNode); |
118 | IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); |
119 | IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode); |
120 | IR_STMT_FUNCTOR_DISPATCH(StoreNode); |
121 | IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); |
122 | IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); |
123 | IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode); |
124 | IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); |
125 | IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); |
126 | IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); |
127 | IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); |
128 | IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); |
129 | IR_STMT_FUNCTOR_DISPATCH(BlockNode); |
130 | IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode); |
131 | return vtable; |
132 | } |
133 | }; |
134 | |
135 | #undef IR_STMT_FUNCTOR_DISPATCH |
136 | #undef STMT_FUNCTOR_DEFAULT |
137 | |
138 | /*! |
139 | * \brief StmtVisitor. |
140 | */ |
141 | class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> { |
142 | public: |
143 | using StmtFunctor::operator(); |
144 | |
145 | protected: |
146 | using StmtFunctor::VisitStmt; |
147 | /*! |
148 | * \brief Visitor to Exprs, can be overriden |
149 | * to do recursive changes to Exprs. |
150 | * \note A common pattern is to call ExprVisitor here, |
151 | * or have a class sub-class both StmtVisitor and ExprVisitor |
152 | * and redirect Visit to ExprMutator::VisitExpr(Expr) |
153 | */ |
154 | virtual void VisitExpr(const PrimExpr& e) {} |
155 | // statement visitor |
156 | void VisitStmt_(const AttrStmtNode* op) override; |
157 | void VisitStmt_(const IfThenElseNode* op) override; |
158 | void VisitStmt_(const LetStmtNode* op) override; |
159 | void VisitStmt_(const ForNode* op) override; |
160 | void VisitStmt_(const WhileNode* op) override; |
161 | void VisitStmt_(const AllocateNode* op) override; |
162 | void VisitStmt_(const AllocateConstNode* op) override; |
163 | void VisitStmt_(const DeclBufferNode* op) override; |
164 | void VisitStmt_(const StoreNode* op) override; |
165 | void VisitStmt_(const BufferStoreNode* op) override; |
166 | void VisitStmt_(const BufferRealizeNode* op) override; |
167 | void VisitStmt_(const AssertStmtNode* op) override; |
168 | void VisitStmt_(const ProducerStoreNode* op) override; |
169 | void VisitStmt_(const ProducerRealizeNode* op) override; |
170 | void VisitStmt_(const PrefetchNode* op) override; |
171 | void VisitStmt_(const SeqStmtNode* op) override; |
172 | void VisitStmt_(const EvaluateNode* op) override; |
173 | void VisitStmt_(const BlockNode* op) override; |
174 | void VisitStmt_(const BlockRealizeNode* op) override; |
175 | }; |
176 | |
177 | /*! |
178 | * \brief StmtMutator that mutates the statements. |
179 | */ |
180 | class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> { |
181 | public: |
182 | /*! |
183 | * \brief Mutate stmt. |
184 | * \param stmt The input statement to be mutated. |
185 | * \return The result of the call |
186 | * \note It is important that stmt is passed by value. |
187 | * so copy on write can be triggered correctly. |
188 | * do mutator(std::move(stmt)) or when copy elison is triggered. |
189 | */ |
190 | Stmt operator()(Stmt stmt) { |
191 | allow_copy_on_write_ = true; |
192 | return VisitStmt(stmt); |
193 | } |
194 | |
195 | protected: |
196 | // We perform copy on write optimizations on the StmtMutator |
197 | // so that an unique copy of parent can be mutated inplace |
198 | // when some of its children changed. |
199 | // We only do such optimization for Stmt nests(instead of Exprs) for now |
200 | // as Stmt's parent state is more likely remain unchanged when one of |
201 | // its child block changes. |
202 | /*! |
203 | * \brief Internal state to indicate whether copy on write is enabled. |
204 | * COW is enabled iff all the parents of the node are unique. |
205 | */ |
206 | bool allow_copy_on_write_{false}; |
207 | /*! |
208 | * \brief Perform copy on write on node. |
209 | * |
210 | * If CopyOnWrite is allowed, directly return |
211 | * a strong reference to the node container. |
212 | * Otherwise, return a copy of the node. |
213 | * |
214 | * \return The result object pointer. |
215 | */ |
216 | template <typename TNode> |
217 | ObjectPtr<TNode> CopyOnWrite(const TNode* node) { |
218 | static_assert(std::is_base_of<StmtNode, TNode>::value, |
219 | "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent " |
220 | "nodes during the recursion. Because the child classes do not necessarily " |
221 | "check the Array, Expr and other structures during the visit, it is only safe to " |
222 | "call this function with StmtNodes for now. " |
223 | "Please create a new node directly in other cases." ); |
224 | if (allow_copy_on_write_) { |
225 | // return the old node. |
226 | return runtime::GetObjectPtr<TNode>(const_cast<TNode*>(node)); |
227 | } else { |
228 | // Make a new copy of the node. |
229 | // need to rely on the default copy constructor |
230 | return runtime::make_object<TNode>(*node); |
231 | } |
232 | } |
233 | /*! |
234 | * \brief Internal mutator that everyone calls. |
235 | * \note To override mutate's behavior, override VisitExpr instead. |
236 | * \param stmt The input stmt. |
237 | * \return The mutated results. |
238 | */ |
239 | Stmt VisitStmt(const Stmt& stmt) override { |
240 | if (allow_copy_on_write_ && !stmt.unique()) { |
241 | allow_copy_on_write_ = false; |
242 | Stmt ret = StmtFunctor::VisitStmt(stmt); |
243 | allow_copy_on_write_ = true; |
244 | return ret; |
245 | } else { |
246 | return StmtFunctor::VisitStmt(stmt); |
247 | } |
248 | } |
249 | /*! |
250 | * \brief Visitor to Exprs, can be overriden |
251 | * to do recursive changes to Exprs. |
252 | * \note A common pattern is to call ExprMutator here, |
253 | * or have a class sub-class both StmtMutator and ExprMutator |
254 | * and redirect Mutate to ExprMutator::Mutate(Expr) |
255 | */ |
256 | virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; } |
257 | // statement visitor |
258 | Stmt VisitStmt_(const AttrStmtNode* op) override; |
259 | Stmt VisitStmt_(const IfThenElseNode* op) override; |
260 | Stmt VisitStmt_(const LetStmtNode* op) override; |
261 | Stmt VisitStmt_(const ForNode* op) override; |
262 | Stmt VisitStmt_(const WhileNode* op) override; |
263 | Stmt VisitStmt_(const AllocateNode* op) override; |
264 | Stmt VisitStmt_(const AllocateConstNode* op) override; |
265 | Stmt VisitStmt_(const DeclBufferNode* op) override; |
266 | Stmt VisitStmt_(const StoreNode* op) override; |
267 | Stmt VisitStmt_(const BufferStoreNode* op) override; |
268 | Stmt VisitStmt_(const BufferRealizeNode* op) override; |
269 | Stmt VisitStmt_(const AssertStmtNode* op) override; |
270 | Stmt VisitStmt_(const ProducerStoreNode* op) override; |
271 | Stmt VisitStmt_(const ProducerRealizeNode* op) override; |
272 | Stmt VisitStmt_(const PrefetchNode* op) override; |
273 | Stmt VisitStmt_(const SeqStmtNode* op) override; |
274 | Stmt VisitStmt_(const EvaluateNode* op) override; |
275 | Stmt VisitStmt_(const BlockNode* op) override; |
276 | Stmt VisitStmt_(const BlockRealizeNode* op) override; |
277 | /*! |
278 | * \brief Alternative advance method for SeqStmtNode. |
279 | * |
280 | * This function can be called when a child class override |
281 | * VisitStmt_(const SeqStmtNode*) to introduce |
282 | * the special behavior to visit |
283 | * |
284 | * \param op The sequence. |
285 | * \param flatten_before_visit Whether to flatten the sequence before visit. |
286 | * \param fmutate The mutate function, can be nullptr, which defaults to Visit. |
287 | * \return The mutated result. |
288 | */ |
289 | Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, |
290 | std::function<Stmt(const Stmt&)> fmutate = nullptr); |
291 | |
292 | // internal helper. |
293 | class Internal; |
294 | }; |
295 | |
296 | /*! |
297 | * \brief Visitor that recursively visit stmts and exprs on them. |
298 | */ |
299 | class StmtExprVisitor : public StmtVisitor, public ExprVisitor { |
300 | public: |
301 | using StmtVisitor::operator(); |
302 | using ExprVisitor::operator(); |
303 | |
304 | protected: |
305 | using ExprVisitor::VisitExpr; |
306 | using StmtVisitor::VisitStmt; |
307 | |
308 | void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); } |
309 | }; |
310 | |
311 | /*! |
312 | * \brief Mutator that recursively mutates stmts and exprs on them. |
313 | */ |
314 | class StmtExprMutator : public StmtMutator, public ExprMutator { |
315 | public: |
316 | using StmtMutator::operator(); |
317 | using ExprMutator::operator(); |
318 | |
319 | protected: |
320 | using ExprMutator::VisitExpr; |
321 | using StmtMutator::VisitExpr; |
322 | |
323 | PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); } |
324 | }; |
325 | |
326 | /*! |
327 | * \brief recursively visit the ir nodes in post DFS order, and transform it |
328 | * |
329 | * \param stmt The ir to be transformed. |
330 | * \param preorder The function called in before recursive mutation |
331 | * If preorder returns None, then the transform will proceed to recursive call. |
332 | * If preorder returns a not None Stmt/Expr, the transformer will simply return it and |
333 | * won't do further recursion. |
334 | * \param postorder The function called after recursive mutation. |
335 | * The recursive mutation result is passed to postorder for further mutation. |
336 | * \param only_enable List of runtime::String. |
337 | * If it is null, all IRNode will call preorder/postorder |
338 | * If it is not null, preorder/postorder will only be called |
339 | * when the IRNode's type key is in the list. |
340 | */ |
341 | TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder, |
342 | const runtime::PackedFunc& postorder, |
343 | Optional<Array<String>> only_enable = NullOpt); |
344 | |
345 | /*! |
346 | * \brief Recursively visit the ir in post DFS order node, apply fvisit |
347 | * Each node is guaranteed to be visited only once. |
348 | * \param node The ir to be visited. |
349 | * \param fvisit The visitor function to be applied. |
350 | */ |
351 | TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit); |
352 | |
353 | /*! |
354 | * \brief Substitute the var specified by vmap. |
355 | * \param stmt The source statement to be substituted |
356 | * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. |
357 | * \return The converted form. |
358 | */ |
359 | TVM_DLL Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var& var)> vmap); |
360 | |
361 | /*! |
362 | * \brief Substitute the var specified by vmap. |
363 | * \param expr The source statement to be substituted |
364 | * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. |
365 | * \return The result. |
366 | */ |
367 | TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var& var)> vmap); |
368 | |
369 | /*! |
370 | * \brief Substitute the var specified by vmap. |
371 | * \param region The object whose vars are to be substituted |
372 | * \param vmap The map of new values. |
373 | * \return The result. |
374 | */ |
375 | TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, PrimExpr>& vmap); |
376 | |
377 | /*! |
378 | * \brief Sugar for substitute via a given map. |
379 | * \param input The input to be updated. |
380 | * \param value_map The map of new values. |
381 | * \return The result. |
382 | * \tparam T the input type, can be PrimExpr or Stmt. |
383 | */ |
384 | template <typename T> |
385 | inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) { |
386 | auto vmap = [&](const Var& var) -> Optional<PrimExpr> { |
387 | auto it = value_map.find(var); |
388 | if (it != value_map.end()) return (*it).second; |
389 | return Optional<PrimExpr>(nullptr); |
390 | }; |
391 | return Substitute(std::move(input), vmap); |
392 | } |
393 | |
394 | /*! |
395 | * \brief Sugar for substitute via a given map. |
396 | * \param input The input to be updated. |
397 | * \param value_map The map of new values. |
398 | * \return The result. |
399 | * \tparam T the input type, can be PrimExpr or Stmt. |
400 | */ |
401 | template <typename T> |
402 | inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>& value_map) { |
403 | auto vmap = [&](const Var& var) -> Optional<PrimExpr> { |
404 | auto it = value_map.find(var.get()); |
405 | if (it != value_map.end()) return (*it).second; |
406 | return Optional<PrimExpr>(nullptr); |
407 | }; |
408 | return Substitute(std::move(input), vmap); |
409 | } |
410 | |
411 | /*! |
412 | * \brief Substitute the var specified by vmap and legalize data types after substitution. |
413 | * \param stmt The source statement to be substituted |
414 | * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. |
415 | * |
416 | * Unlike `Substitute`, this allows the substitution to change the data type of the expression. |
417 | * |
418 | * \sa Substitute |
419 | * \return The result. |
420 | */ |
421 | TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt, |
422 | std::function<Optional<PrimExpr>(const Var&)> vmap); |
423 | |
424 | /*! |
425 | * \brief Substitute the var specified by vmap and legalize data types after substitution. |
426 | * \param expr The source statement to be substituted |
427 | * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. |
428 | * |
429 | * Unlike `Substitute`, this allows the substitution to change the data type of the expression. |
430 | * |
431 | * \sa Substitute |
432 | * \return The result. |
433 | */ |
434 | TVM_DLL PrimExpr SubstituteWithDataTypeLegalization( |
435 | PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap); |
436 | |
437 | /*! |
438 | * \brief Recursively visit the IR in pre DFS order node, apply fvisit. |
439 | * If fvisit returns false, it won't visit the children of the node. |
440 | * \param stmt_or_expr The ir to be visited. |
441 | * \param fvisit The visitor function to be applied. If fvisit returns false, it won't visit the |
442 | * children of the node |
443 | */ |
444 | TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr, |
445 | const std::function<bool(const ObjectRef&)>& fvisit); |
446 | |
447 | /*! |
448 | * \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar. |
449 | * This pass works as a simple DeepCopy to duplicate a function with different Vars and |
450 | * Buffers but the same behavior |
451 | * \param func The input PrimFunc. |
452 | * \return The renewed func. |
453 | */ |
454 | TVM_DLL PrimFunc RenewDefs(const PrimFunc& func); |
455 | |
456 | /*! |
457 | * \brief Check if the statement contains the specified node type. |
458 | * |
459 | * This utility potentially walks the entire statement, and should |
460 | * therefore not be used if it could otherwise be merged with another |
461 | * pass. |
462 | * |
463 | * \param stmt The statement to be searched |
464 | * \return Whether stmt contains Node |
465 | */ |
466 | template <typename Node, typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>> |
467 | bool ContainsNode(const Stmt& stmt) { |
468 | struct Visitor : StmtVisitor { |
469 | // Early bail-out, if we already found the node. |
470 | void VisitStmt(const Stmt& stmt) final { |
471 | if (contains_node) { |
472 | return; |
473 | } |
474 | StmtVisitor::VisitStmt(stmt); |
475 | } |
476 | |
477 | void VisitStmt_(const Node* block) override { contains_node = true; } |
478 | |
479 | bool contains_node{false}; |
480 | }; |
481 | |
482 | Visitor visitor; |
483 | visitor(stmt); |
484 | return visitor.contains_node; |
485 | } |
486 | |
487 | } // namespace tir |
488 | } // namespace tvm |
489 | |
490 | #endif // TVM_TIR_STMT_FUNCTOR_H_ |
491 | |