1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_
20#define TVM_TIR_SCHEDULE_TRANSFORM_H_
21
22#include <tvm/tir/schedule/schedule.h>
23#include <tvm/tir/schedule/state.h>
24#include <tvm/tir/stmt_functor.h>
25
26#include <unordered_map>
27#include <utility>
28
29#include "../../arith/ir_mutator_with_analyzer.h"
30#include "../ir/functor_common.h"
31
32namespace tvm {
33namespace tir {
34
35/******** Annotation ********/
36
37/*!
38 * \brief Create a new block with the given annotation added
39 * \param block The block with original annotation
40 * \param attr_key The annotation key to be added
41 * \param attr_value The annotation value to be added
42 * \return A new block with the given annotation as its last annotation
43 */
44Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value);
45
46/******** Buffer Related ********/
47
48/*!
49 * \brief Create a new buffer by changing the storage scope.
50 * \param buffer The given buffer.
51 * \param scope The target storage scope.
52 * \return The new buffer with target storage scope.
53 */
54Buffer WithScope(const Buffer& buffer, const String& scope);
55
56/*!
57 * \brief Replaces the buffer within the specific sequence of regions
58 * \param regions The regions whose buffers are to be replaced
59 * \param source The buffer to be replaced
60 * \param target The buffer to be replaced to
61 * \return The new sequence of regions after replacement
62 */
63Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
64 const Buffer& target);
65
66/*!
67 * \brief Replaces the buffer within the specific sequence of match_buffers
68 * \param match_buffers The match_buffers whose buffers are to be replaced
69 * \param source The buffer to be replaced
70 * \param target The buffer to be replaced to
71 * \return The new sequence of match_buffers after replacement
72 */
73Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
74 const Buffer& target);
75
76/*!
77 * \brief Replaces the buffer region within the specific sequence of regions
78 * \param regions The regions to be replaced
79 * \param source_buffer The buffer to whose region is to be replaced
80 * \param target The buffer region to be replaced to
81 * \return The new sequence of regions after replacement
82 */
83Array<BufferRegion> ReplaceBufferRegion(Array<BufferRegion> regions, const Buffer& source_buffer,
84 const BufferRegion& target);
85
86/*!
87 * \brief Replaces the buffer region within the specific sequence of match_buffers
88 * \param regions The match_buffers to be replaced
89 * \param source_buffer The buffer to whose region is to be replaced
90 * \param target The buffer region to be replaced to
91 * \return The new sequence of match_buffers after replacement
92 */
93Array<MatchBufferRegion> ReplaceBufferRegion(Array<MatchBufferRegion> match_buffers,
94 const Buffer& source_buffer,
95 const BufferRegion& target);
96
97/*!
98 * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
99 * collects the block sref reuse information for the following replacement.
100 *
101 * If the buffer to be replaced in used as the source in `match_buffers`, depending the specific
102 * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
103 * case, this class should be subclassed to explicitly handle `match_buffers`.
104 */
105class ReplaceBufferMutator : public StmtExprMutator {
106 public:
107 /*!
108 * \brief The constructor
109 * \param old_buffer The old buffer
110 * \param new_buffer The new buffer
111 * \param block_sref_reuse Optional map to record mapping between old and new blocks that reuse
112 * sref.
113 */
114 ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
115 Map<Block, Block>* block_sref_reuse);
116
117 ReplaceBufferMutator(const Map<Buffer, Buffer>& buffer_map, Map<Block, Block>* block_sref_reuse);
118
119 protected:
120 using StmtExprMutator::VisitExpr_;
121 using StmtExprMutator::VisitStmt_;
122
123 PrimExpr VisitExpr_(const VarNode* var) final;
124
125 template <typename Node>
126 Node VisitBufferAccess(Node node) {
127 auto it = buffer_var_map_.find(node->buffer->data.get());
128 if (it != buffer_var_map_.end()) {
129 node.CopyOnWrite()->buffer = it->second;
130 }
131 return node;
132 }
133
134 Stmt VisitStmt_(const BufferStoreNode* op) final;
135
136 PrimExpr VisitExpr_(const BufferLoadNode* op) final;
137
138 virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer);
139
140 Stmt VisitStmt_(const BlockNode* block) override;
141
142 /*!
143 * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
144 * MatchBufferRegion.
145 */
146 std::unordered_map<const VarNode*, Buffer> buffer_var_map_;
147 /*! \brief The block sref reuse map for the following replacement */
148 Map<Block, Block>* block_sref_reuse_;
149};
150
151/******** Block Removal ********/
152
153/*!
154 * \brief Construct a new AST, with a specific sref tree leaf removed.
155 * The leaf's ancestors who have only a single child will be removed too.
156 * \param leaf_block_sref The block/loop sref to the sref tree leaf to be removed
157 * \param src_stmt The root of the subtree where the replacement begins
158 * \param tgt_stmt The root of the subtree after the replacement
159 * \return A boolean indicating if the leaf can be removed successfully
160 * \note Read before use:
161 * 1) Removal is not conducted beyond scope-level.
162 * 2) This method only works properly when the scope root is a stage pipeline.
163 *
164 * An example of the removal plan, say we are removing the leaf block "B" from the AST.
165 *
166 * \code
167 * with block([], "scope_root"):
168 * ...
169 * with block([128, 128], "B") as [vi, vj]:
170 * B[vi, vj] = A[vi, vj] + 1.0
171 * with block([128, 128], "C") as [vi, vj]:
172 * C[vi, vj] = B[vi, vj] * 2.0
173 * \endcode
174 *
175 * Ths method does not mutate the AST, instead it returns the a `(src_stmt, tgt_stmt)` pair as a
176 * plan to substitute certain pieces of the IR.
177 *
178 * In our example, it returns block "scope_root" as `src_stmt`, and the result `tgt_stmt` is:
179 *
180 * \code
181 * with block([], "scope_root"):
182 * ...
183 * with block([128, 128], "C") as [vi, vj]:
184 * C[vi, vj] = B[vi, vj] * 2.0
185 * \endcode
186 */
187void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
188 Stmt* src_stmt, Stmt* tgt_stmt);
189
190/*!
191 * \brief Tile a subset of loops in the block according to the given tensor intrinsic.
192 * \param self The schedule to which tiling is applied
193 * \param block_rv The block whose subset of loops will be tiled
194 * \param intrin_name The name of a tensor intrinsic, must be registerd via
195 * TensorIntrin.register(...) beforehand
196 * \param allow_padding Whether to allow padding when tiling
197 * \return LoopRV corresponding to the outermost loop of a
198 * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found
199 */
200Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
201 const String& intrin_name, bool allow_padding = false);
202
203/******** Block mutation ********/
204
205/*!
206 * \brief Simplifier for indices of buffer access and block buffer access regions.
207 */
208class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer {
209 public:
210 /*!
211 * \brief Simplify indices of buffer access and block buffer access regions in the statement
212 * \param stmt The statement to be simplified
213 * \param analyzer The arithmetic analyzer
214 * \return The simplified statement
215 */
216 static Stmt Simplify(const Stmt& stmt, arith::Analyzer* analyzer) {
217 BlockBufferAccessSimplifier simplifier(analyzer);
218 return simplifier(stmt);
219 }
220
221 private:
222 explicit BlockBufferAccessSimplifier(arith::Analyzer* analyzer)
223 : IRMutatorWithAnalyzer(analyzer) {}
224
225 using IRMutatorWithAnalyzer::VisitExpr_;
226 using IRMutatorWithAnalyzer::VisitStmt_;
227
228 void SimplifyAccessRegion(Array<BufferRegion>* old_access_regions);
229 void SimplifyBufferIndices(Array<PrimExpr>* indices);
230
231 Stmt VisitStmt_(const BlockNode* op) final;
232 Stmt VisitStmt_(const BufferStoreNode* op) final;
233 PrimExpr VisitExpr_(const BufferLoadNode* op) final;
234};
235
236} // namespace tir
237} // namespace tvm
238
239#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_
240