1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/stmt_functor.h
22 *
 * \brief Functors for TIR statements, together with
 *        utility functions that apply common functors.
25 */
26#ifndef TVM_TIR_STMT_FUNCTOR_H_
27#define TVM_TIR_STMT_FUNCTOR_H_
28
29#include <tvm/node/functor.h>
30#include <tvm/tir/expr.h>
31#include <tvm/tir/expr_functor.h>
32#include <tvm/tir/function.h>
33#include <tvm/tir/stmt.h>
34
35#include <unordered_map>
36#include <utility>
37
38namespace tvm {
39namespace tir {
/*!
 * \brief Same as ExprFunctor except it is applied on statements.
 *
 * Only the specialization for signatures of the form
 * R(const Stmt& n, Args... args) is defined; the primary template is
 * intentionally left as a declaration.
 *
 * \tparam FType The function signature.
 * \sa ExprFunctor
 */
template <typename FType>
class StmtFunctor;
47
// Default body for each VisitStmt_ overload of StmtFunctor: forwards the node
// (which must be named `op`) and the extra arguments (which must be named
// `args`) to VisitStmtDefault_.
#define STMT_FUNCTOR_DEFAULT \
  { return VisitStmtDefault_(op, std::forward<Args>(args)...); }

// Registers a vtable entry for node type OP. The entry downcasts the ObjectRef
// to `const OP*` and calls the matching VisitStmt_ overload on `self`.
// Meant to be expanded inside StmtFunctor::InitVTable, where `vtable`, `TSelf`
// and `Args` are in scope.
#define IR_STMT_FUNCTOR_DISPATCH(OP)                                                       \
  vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {     \
    return self->VisitStmt_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
  });
55
/*!
 * \brief StmtFunctor specialization for signatures R(const Stmt& n, Args... args).
 *
 * VisitStmt dispatches on the node type of the statement through a vtable
 * built once by InitVTable, and calls the matching VisitStmt_ overload.
 * Overloads that a subclass does not override forward to VisitStmtDefault_,
 * which reports a fatal error.
 *
 * \tparam R The return type of every visit function.
 * \tparam Args Extra arguments forwarded to every visit function.
 */
template <typename R, typename... Args>
class StmtFunctor<R(const Stmt& n, Args... args)> {
 private:
  /*! \brief This specialization; used as the receiver type stored in the vtable. */
  using TSelf = StmtFunctor<R(const Stmt& n, Args... args)>;
  /*! \brief Per-node-type dispatch table; entries are registered in InitVTable. */
  using FType = NodeFunctor<R(const ObjectRef& n, TSelf* self, Args... args)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~StmtFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The stmt node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward<Args>(args)...); }
  /*!
   * \brief The functor call.
   * \param n The stmt node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitStmt(const Stmt& n, Args... args) {
    // Function-local static: the table is built on first use and then shared
    // by every instance of this specialization.
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overridden by subclass. Each default body
  // (STMT_FUNCTOR_DEFAULT) forwards to VisitStmtDefault_.
  virtual R VisitStmt_(const LetStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const AttrStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const IfThenElseNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const DeclBufferNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
  /*!
   * \brief Fallback for node types whose VisitStmt_ overload is not overridden.
   * \note Logs a fatal error that names the node's type key.
   */
  virtual R VisitStmtDefault_(const Object* op, Args...) {
    LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
  }

 private:
  // initialize the vtable: register one dispatch entry per supported node type.
  static FType InitVTable() {
    FType vtable;
    IR_STMT_FUNCTOR_DISPATCH(LetStmtNode);
    IR_STMT_FUNCTOR_DISPATCH(AttrStmtNode);
    IR_STMT_FUNCTOR_DISPATCH(IfThenElseNode);
    IR_STMT_FUNCTOR_DISPATCH(ForNode);
    IR_STMT_FUNCTOR_DISPATCH(WhileNode);
    IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
    IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
    IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode);
    IR_STMT_FUNCTOR_DISPATCH(StoreNode);
    IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
    IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
    IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode);
    IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
    IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
    IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
    IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
    IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
    IR_STMT_FUNCTOR_DISPATCH(BlockNode);
    IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
    return vtable;
  }
};
134
// The helper macros are only intended for the StmtFunctor definition; keep them
// from leaking into files that include this header.
#undef IR_STMT_FUNCTOR_DISPATCH
#undef STMT_FUNCTOR_DEFAULT
137
138/*!
139 * \brief StmtVisitor.
140 */
141class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
142 public:
143 using StmtFunctor::operator();
144
145 protected:
146 using StmtFunctor::VisitStmt;
147 /*!
148 * \brief Visitor to Exprs, can be overriden
149 * to do recursive changes to Exprs.
150 * \note A common pattern is to call ExprVisitor here,
151 * or have a class sub-class both StmtVisitor and ExprVisitor
152 * and redirect Visit to ExprMutator::VisitExpr(Expr)
153 */
154 virtual void VisitExpr(const PrimExpr& e) {}
155 // statement visitor
156 void VisitStmt_(const AttrStmtNode* op) override;
157 void VisitStmt_(const IfThenElseNode* op) override;
158 void VisitStmt_(const LetStmtNode* op) override;
159 void VisitStmt_(const ForNode* op) override;
160 void VisitStmt_(const WhileNode* op) override;
161 void VisitStmt_(const AllocateNode* op) override;
162 void VisitStmt_(const AllocateConstNode* op) override;
163 void VisitStmt_(const DeclBufferNode* op) override;
164 void VisitStmt_(const StoreNode* op) override;
165 void VisitStmt_(const BufferStoreNode* op) override;
166 void VisitStmt_(const BufferRealizeNode* op) override;
167 void VisitStmt_(const AssertStmtNode* op) override;
168 void VisitStmt_(const ProducerStoreNode* op) override;
169 void VisitStmt_(const ProducerRealizeNode* op) override;
170 void VisitStmt_(const PrefetchNode* op) override;
171 void VisitStmt_(const SeqStmtNode* op) override;
172 void VisitStmt_(const EvaluateNode* op) override;
173 void VisitStmt_(const BlockNode* op) override;
174 void VisitStmt_(const BlockRealizeNode* op) override;
175};
176
177/*!
178 * \brief StmtMutator that mutates the statements.
179 */
180class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
181 public:
182 /*!
183 * \brief Mutate stmt.
184 * \param stmt The input statement to be mutated.
185 * \return The result of the call
186 * \note It is important that stmt is passed by value.
187 * so copy on write can be triggered correctly.
188 * do mutator(std::move(stmt)) or when copy elison is triggered.
189 */
190 Stmt operator()(Stmt stmt) {
191 allow_copy_on_write_ = true;
192 return VisitStmt(stmt);
193 }
194
195 protected:
196 // We perform copy on write optimizations on the StmtMutator
197 // so that an unique copy of parent can be mutated inplace
198 // when some of its children changed.
199 // We only do such optimization for Stmt nests(instead of Exprs) for now
200 // as Stmt's parent state is more likely remain unchanged when one of
201 // its child block changes.
202 /*!
203 * \brief Internal state to indicate whether copy on write is enabled.
204 * COW is enabled iff all the parents of the node are unique.
205 */
206 bool allow_copy_on_write_{false};
207 /*!
208 * \brief Perform copy on write on node.
209 *
210 * If CopyOnWrite is allowed, directly return
211 * a strong reference to the node container.
212 * Otherwise, return a copy of the node.
213 *
214 * \return The result object pointer.
215 */
216 template <typename TNode>
217 ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
218 static_assert(std::is_base_of<StmtNode, TNode>::value,
219 "StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
220 "nodes during the recursion. Because the child classes do not necessarily "
221 "check the Array, Expr and other structures during the visit, it is only safe to "
222 "call this function with StmtNodes for now. "
223 "Please create a new node directly in other cases.");
224 if (allow_copy_on_write_) {
225 // return the old node.
226 return runtime::GetObjectPtr<TNode>(const_cast<TNode*>(node));
227 } else {
228 // Make a new copy of the node.
229 // need to rely on the default copy constructor
230 return runtime::make_object<TNode>(*node);
231 }
232 }
233 /*!
234 * \brief Internal mutator that everyone calls.
235 * \note To override mutate's behavior, override VisitExpr instead.
236 * \param stmt The input stmt.
237 * \return The mutated results.
238 */
239 Stmt VisitStmt(const Stmt& stmt) override {
240 if (allow_copy_on_write_ && !stmt.unique()) {
241 allow_copy_on_write_ = false;
242 Stmt ret = StmtFunctor::VisitStmt(stmt);
243 allow_copy_on_write_ = true;
244 return ret;
245 } else {
246 return StmtFunctor::VisitStmt(stmt);
247 }
248 }
249 /*!
250 * \brief Visitor to Exprs, can be overriden
251 * to do recursive changes to Exprs.
252 * \note A common pattern is to call ExprMutator here,
253 * or have a class sub-class both StmtMutator and ExprMutator
254 * and redirect Mutate to ExprMutator::Mutate(Expr)
255 */
256 virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; }
257 // statement visitor
258 Stmt VisitStmt_(const AttrStmtNode* op) override;
259 Stmt VisitStmt_(const IfThenElseNode* op) override;
260 Stmt VisitStmt_(const LetStmtNode* op) override;
261 Stmt VisitStmt_(const ForNode* op) override;
262 Stmt VisitStmt_(const WhileNode* op) override;
263 Stmt VisitStmt_(const AllocateNode* op) override;
264 Stmt VisitStmt_(const AllocateConstNode* op) override;
265 Stmt VisitStmt_(const DeclBufferNode* op) override;
266 Stmt VisitStmt_(const StoreNode* op) override;
267 Stmt VisitStmt_(const BufferStoreNode* op) override;
268 Stmt VisitStmt_(const BufferRealizeNode* op) override;
269 Stmt VisitStmt_(const AssertStmtNode* op) override;
270 Stmt VisitStmt_(const ProducerStoreNode* op) override;
271 Stmt VisitStmt_(const ProducerRealizeNode* op) override;
272 Stmt VisitStmt_(const PrefetchNode* op) override;
273 Stmt VisitStmt_(const SeqStmtNode* op) override;
274 Stmt VisitStmt_(const EvaluateNode* op) override;
275 Stmt VisitStmt_(const BlockNode* op) override;
276 Stmt VisitStmt_(const BlockRealizeNode* op) override;
277 /*!
278 * \brief Alternative advance method for SeqStmtNode.
279 *
280 * This function can be called when a child class override
281 * VisitStmt_(const SeqStmtNode*) to introduce
282 * the special behavior to visit
283 *
284 * \param op The sequence.
285 * \param flatten_before_visit Whether to flatten the sequence before visit.
286 * \param fmutate The mutate function, can be nullptr, which defaults to Visit.
287 * \return The mutated result.
288 */
289 Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
290 std::function<Stmt(const Stmt&)> fmutate = nullptr);
291
292 // internal helper.
293 class Internal;
294};
295
296/*!
297 * \brief Visitor that recursively visit stmts and exprs on them.
298 */
299class StmtExprVisitor : public StmtVisitor, public ExprVisitor {
300 public:
301 using StmtVisitor::operator();
302 using ExprVisitor::operator();
303
304 protected:
305 using ExprVisitor::VisitExpr;
306 using StmtVisitor::VisitStmt;
307
308 void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); }
309};
310
311/*!
312 * \brief Mutator that recursively mutates stmts and exprs on them.
313 */
314class StmtExprMutator : public StmtMutator, public ExprMutator {
315 public:
316 using StmtMutator::operator();
317 using ExprMutator::operator();
318
319 protected:
320 using ExprMutator::VisitExpr;
321 using StmtMutator::VisitExpr;
322
323 PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); }
324};
325
/*!
 * \brief Recursively visit the IR nodes in post-DFS order and transform them.
 *
 * \param stmt The IR to be transformed.
 * \param preorder The function called before the recursive mutation.
 *  If preorder returns None, the transform proceeds to the recursive call.
 *  If preorder returns a Stmt/Expr that is not None, the transformer simply
 *  returns it and does not recurse further.
 * \param postorder The function called after the recursive mutation.
 *  The recursive mutation result is passed to postorder for further mutation.
 * \param only_enable List of runtime::String type keys.
 *  If it is NullOpt (the default), preorder/postorder are called on every IR node.
 *  Otherwise, preorder/postorder are only called when the node's type key is
 *  in the list.
 * \return The transformed statement.
 */
TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder,
                         const runtime::PackedFunc& postorder,
                         Optional<Array<String>> only_enable = NullOpt);
344
/*!
 * \brief Recursively visit the IR in post-DFS order and apply fvisit to each node.
 *  Each node is guaranteed to be visited only once.
 * \param node The IR to be visited.
 * \param fvisit The visitor function to be applied.
 */
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);
352
/*!
 * \brief Substitute the vars specified by vmap.
 * \param stmt The source statement to be substituted.
 * \param vmap Returns the new value if re-mapping is needed, otherwise returns
 *  an empty Optional (NullOpt).
 * \return The converted form.
 */
TVM_DLL Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var& var)> vmap);
360
/*!
 * \brief Substitute the vars specified by vmap.
 * \param expr The source expression to be substituted.
 * \param vmap Returns the new value if re-mapping is needed, otherwise returns
 *  an empty Optional (NullOpt).
 * \return The result.
 */
TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var& var)> vmap);
368
/*!
 * \brief Substitute the vars specified by vmap in each range of a region.
 * \param region The ranges whose vars are to be substituted.
 * \param vmap The map from vars to their new values.
 * \return The result.
 */
TVM_DLL Array<Range> Substitute(const Array<Range>& region, const Map<Var, PrimExpr>& vmap);
376
377/*!
378 * \brief Sugar for substitute via a given map.
379 * \param input The input to be updated.
380 * \param value_map The map of new values.
381 * \return The result.
382 * \tparam T the input type, can be PrimExpr or Stmt.
383 */
384template <typename T>
385inline auto Substitute(T input, const Map<Var, PrimExpr>& value_map) {
386 auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
387 auto it = value_map.find(var);
388 if (it != value_map.end()) return (*it).second;
389 return Optional<PrimExpr>(nullptr);
390 };
391 return Substitute(std::move(input), vmap);
392}
393
394/*!
395 * \brief Sugar for substitute via a given map.
396 * \param input The input to be updated.
397 * \param value_map The map of new values.
398 * \return The result.
399 * \tparam T the input type, can be PrimExpr or Stmt.
400 */
401template <typename T>
402inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
403 auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
404 auto it = value_map.find(var.get());
405 if (it != value_map.end()) return (*it).second;
406 return Optional<PrimExpr>(nullptr);
407 };
408 return Substitute(std::move(input), vmap);
409}
410
/*!
 * \brief Substitute the vars specified by vmap and legalize data types after substitution.
 * \param stmt The source statement to be substituted.
 * \param vmap Returns the new value if re-mapping is needed, otherwise returns
 *  an empty Optional (NullOpt).
 *
 * Unlike `Substitute`, this allows the substitution to change the data type of the expression.
 *
 * \sa Substitute
 * \return The result.
 */
TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt,
                                                std::function<Optional<PrimExpr>(const Var&)> vmap);
423
/*!
 * \brief Substitute the vars specified by vmap and legalize data types after substitution.
 * \param expr The source expression to be substituted.
 * \param vmap Returns the new value if re-mapping is needed, otherwise returns
 *  an empty Optional (NullOpt).
 *
 * Unlike `Substitute`, this allows the substitution to change the data type of the expression.
 *
 * \sa Substitute
 * \return The result.
 */
TVM_DLL PrimExpr SubstituteWithDataTypeLegalization(
    PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap);
436
/*!
 * \brief Recursively visit the IR in pre-DFS order and apply fvisit to each node.
 * \param stmt_or_expr The IR to be visited.
 * \param fvisit The visitor function to be applied. If fvisit returns false for
 *  a node, the children of that node are not visited.
 */
TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
                           const std::function<bool(const ObjectRef&)>& fvisit);
446
/*!
 * \brief Renew the definition nodes for a TIR function, including Var, Buffer and IterVar.
 *  This pass works as a simple DeepCopy: it duplicates a function with different
 *  Vars and Buffers but the same behavior.
 * \param func The input PrimFunc.
 * \return The renewed func.
 */
TVM_DLL PrimFunc RenewDefs(const PrimFunc& func);
455
456/*!
457 * \brief Check if the statement contains the specified node type.
458 *
459 * This utility potentially walks the entire statement, and should
460 * therefore not be used if it could otherwise be merged with another
461 * pass.
462 *
463 * \param stmt The statement to be searched
464 * \return Whether stmt contains Node
465 */
466template <typename Node, typename = std::enable_if_t<std::is_base_of_v<StmtNode, Node>>>
467bool ContainsNode(const Stmt& stmt) {
468 struct Visitor : StmtVisitor {
469 // Early bail-out, if we already found the node.
470 void VisitStmt(const Stmt& stmt) final {
471 if (contains_node) {
472 return;
473 }
474 StmtVisitor::VisitStmt(stmt);
475 }
476
477 void VisitStmt_(const Node* block) override { contains_node = true; }
478
479 bool contains_node{false};
480 };
481
482 Visitor visitor;
483 visitor(stmt);
484 return visitor.contains_node;
485}
486
487} // namespace tir
488} // namespace tvm
489
490#endif // TVM_TIR_STMT_FUNCTOR_H_
491