1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ |
20 | #define TVM_TIR_SCHEDULE_TRANSFORM_H_ |
21 | |
22 | #include <tvm/tir/schedule/schedule.h> |
23 | #include <tvm/tir/schedule/state.h> |
24 | #include <tvm/tir/stmt_functor.h> |
25 | |
26 | #include <unordered_map> |
27 | #include <utility> |
28 | |
29 | #include "../../arith/ir_mutator_with_analyzer.h" |
30 | #include "../ir/functor_common.h" |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | /******** Annotation ********/ |
36 | |
37 | /*! |
38 | * \brief Create a new block with the given annotation added |
39 | * \param block The block carrying the original annotations (taken as a pointer-to-const) |
40 | * \param attr_key The key of the annotation to be added |
41 | * \param attr_value The value of the annotation to be added |
42 | * \return A new block with the given annotation as its last annotation |
43 | */ |
44 | Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); |
45 | |
46 | /******** Buffer Related ********/ |
47 | |
48 | /*! |
49 | * \brief Create a new buffer by changing the storage scope. |
50 | * \param buffer The given buffer (taken by const reference). |
51 | * \param scope The target storage scope. |
52 | * \return The new buffer with the target storage scope. |
53 | */ |
54 | Buffer WithScope(const Buffer& buffer, const String& scope); |
55 | |
56 | /*! |
57 | * \brief Replaces the buffer within the specific sequence of regions |
58 | * \param regions The regions whose buffers are to be replaced |
59 | * \param source The buffer to be replaced |
60 | * \param target The buffer that takes the place of `source` |
61 | * \return The new sequence of regions after replacement |
62 | */ |
63 | Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source, |
64 | const Buffer& target); |
65 | |
66 | /*! |
67 | * \brief Replaces the buffer within the specific sequence of match_buffers |
68 | * \param match_buffers The match_buffers whose buffers are to be replaced |
69 | * \param source The buffer to be replaced |
70 | * \param target The buffer that takes the place of `source` |
71 | * \return The new sequence of match_buffers after replacement |
72 | */ |
73 | Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source, |
74 | const Buffer& target); |
75 | |
76 | /*! |
77 | * \brief Replaces the buffer region within the specific sequence of regions |
78 | * \param regions The regions to be replaced |
79 | * \param source_buffer The buffer whose region is to be replaced |
80 | * \param target The buffer region that takes its place |
81 | * \return The new sequence of regions after replacement |
82 | */ |
83 | Array<BufferRegion> ReplaceBufferRegion(Array<BufferRegion> regions, const Buffer& source_buffer, |
84 | const BufferRegion& target); |
85 | |
86 | /*! |
87 | * \brief Replaces the buffer region within the specific sequence of match_buffers |
88 | * \param match_buffers The match_buffers to be replaced |
89 | * \param source_buffer The buffer whose region is to be replaced |
90 | * \param target The buffer region that takes its place |
91 | * \return The new sequence of match_buffers after replacement |
92 | */ |
93 | Array<MatchBufferRegion> ReplaceBufferRegion(Array<MatchBufferRegion> match_buffers, |
94 | const Buffer& source_buffer, |
95 | const BufferRegion& target); |
96 | |
97 | /*! |
98 | * \brief A helper mutator which recursively replaces the old buffer with the new buffer and |
99 | * collects the block sref reuse information for the following replacement. |
100 | * |
101 | * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific |
102 | * use case, the target buffers in `match_buffers` may also need to be mutated. In this |
103 | * case, this class should be subclassed to explicitly handle `match_buffers`. |
104 | */ |
105 | class ReplaceBufferMutator : public StmtExprMutator { |
106 | public: |
107 | /*! |
108 | * \brief Constructor for replacing a single buffer |
109 | * \param old_buffer The old buffer |
110 | * \param new_buffer The new buffer |
111 | * \param block_sref_reuse Optional map to record mapping between old and new blocks that reuse |
112 | * sref. |
113 | */ |
114 | ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer, |
115 | Map<Block, Block>* block_sref_reuse); |
116 | /*! \brief Constructor from a buffer-to-buffer map (presumably old -> new; confirm at the definition). */ |
117 | ReplaceBufferMutator(const Map<Buffer, Buffer>& buffer_map, Map<Block, Block>* block_sref_reuse); |
118 | |
119 | protected: |
120 | using StmtExprMutator::VisitExpr_; |
121 | using StmtExprMutator::VisitStmt_; |
122 | /*! \brief Visitor for variables; final, declared here only (the definition is not in this header). */ |
123 | PrimExpr VisitExpr_(const VarNode* var) final; |
124 | /*! \brief If the node's buffer data var is a key of buffer_var_map_, swap in the mapped buffer. */ |
125 | template <typename Node> |
126 | Node VisitBufferAccess(Node node) { |
127 | auto it = buffer_var_map_.find(node->buffer->data.get()); /* lookup keyed by the raw VarNode pointer */ |
128 | if (it != buffer_var_map_.end()) { |
129 | node.CopyOnWrite()->buffer = it->second; /* only touched on a hit; otherwise node is returned as-is */ |
130 | } |
131 | return node; |
132 | } |
133 | /*! \brief Visitor for buffer stores; final, declared here only. */ |
134 | Stmt VisitStmt_(const BufferStoreNode* op) final; |
135 | /*! \brief Visitor for buffer loads; final, declared here only. */ |
136 | PrimExpr VisitExpr_(const BufferLoadNode* op) final; |
137 | /*! \brief Virtual hook for one match_buffer entry; subclasses may override it (see the class note). */ |
138 | virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer); |
139 | /*! \brief Visitor for blocks; marked override but not final, so subclasses can still override it. */ |
140 | Stmt VisitStmt_(const BlockNode* block) override; |
141 | |
142 | /*! |
143 | * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in |
144 | * MatchBufferRegion. |
145 | */ |
146 | std::unordered_map<const VarNode*, Buffer> buffer_var_map_; |
147 | /*! \brief The block sref reuse map for the following replacement (held as a raw pointer) */ |
148 | Map<Block, Block>* block_sref_reuse_; |
149 | }; |
150 | |
151 | /******** Block Removal ********/ |
152 | |
153 | /*! |
154 | * \brief Construct a new AST, with a specific sref tree leaf removed. |
155 | * The leaf's ancestors who have only a single child will be removed too. |
156 | * \param self The schedule state |
157 | * \param leaf_block_sref The block/loop sref to the sref tree leaf to be removed |
158 | * \param src_stmt Output pointer: the root of the subtree where the replacement begins |
159 | * \param tgt_stmt Output pointer: the root of the subtree after the replacement |
160 | * \note Read before use: |
161 | * 1) Removal is not conducted beyond scope-level. |
162 | * 2) This method only works properly when the scope root is a stage pipeline. |
163 | * |
164 | * An example of the removal plan, say we are removing the leaf block "B" from the AST. |
165 | * |
166 | * \code |
167 |   with block([], "scope_root"): |
168 |     ... |
169 |     with block([128, 128], "B") as [vi, vj]: |
170 |       B[vi, vj] = A[vi, vj] + 1.0 |
171 |     with block([128, 128], "C") as [vi, vj]: |
172 |       C[vi, vj] = B[vi, vj] * 2.0 |
173 | * \endcode |
174 | * |
175 | * This method does not mutate the AST; instead it reports a `(src_stmt, tgt_stmt)` pair, through its |
176 | * two output pointers, as a plan to substitute certain pieces of the IR. The function returns void. |
177 | * |
178 | * In our example, it returns block "scope_root" as `src_stmt`, and the result `tgt_stmt` is: |
179 | * |
180 | * \code |
181 |   with block([], "scope_root"): |
182 |     ... |
183 |     with block([128, 128], "C") as [vi, vj]: |
184 |       C[vi, vj] = B[vi, vj] * 2.0 |
185 | * \endcode |
186 | */ |
187 | void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref, |
188 | Stmt* src_stmt, Stmt* tgt_stmt); |
189 | |
190 | /*! |
191 | * \brief Tile a subset of loops in the block according to the given tensor intrinsic. |
192 | * \param sch The schedule to which tiling is applied |
193 | * \param block_rv The block whose subset of loops will be tiled |
194 | * \param intrin_name The name of a tensor intrinsic, must be registered via |
195 | * TensorIntrin.register(...) beforehand |
196 | * \param allow_padding Whether to allow padding when tiling (defaults to false) |
197 | * \return LoopRV corresponding to the outermost loop of a |
198 | * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found |
199 | */ |
200 | Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv, |
201 | const String& intrin_name, bool allow_padding = false); |
202 | |
203 | /******** Block mutation ********/ |
204 | |
205 | /*! |
206 | * \brief Simplifier for indices of buffer access and block buffer access regions. |
207 | */ |
208 | class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer { |
209 | public: |
210 | /*! |
211 | * \brief Simplify indices of buffer access and block buffer access regions in the statement |
212 | * \param stmt The statement to be simplified |
213 | * \param analyzer The arithmetic analyzer, handed to the IRMutatorWithAnalyzer base class |
214 | * \return The simplified statement |
215 | */ |
216 | static Stmt Simplify(const Stmt& stmt, arith::Analyzer* analyzer) { |
217 | BlockBufferAccessSimplifier simplifier(analyzer); /* local mutator constructed with the caller's analyzer */ |
218 | return simplifier(stmt); /* invoke the mutator on the statement and return its result */ |
219 | } |
220 | /*! \brief Private constructor: outside code obtains the simplifier only through the static Simplify(). */ |
221 | private: |
222 | explicit BlockBufferAccessSimplifier(arith::Analyzer* analyzer) |
223 | : IRMutatorWithAnalyzer(analyzer) {} |
224 | /* Re-expose the base-class visitor overloads that the overrides declared below would otherwise hide. */ |
225 | using IRMutatorWithAnalyzer::VisitExpr_; |
226 | using IRMutatorWithAnalyzer::VisitStmt_; |
227 | /* Helpers taking pointers to the arrays; presumably they simplify in place (confirm at the definitions). */ |
228 | void SimplifyAccessRegion(Array<BufferRegion>* old_access_regions); |
229 | void SimplifyBufferIndices(Array<PrimExpr>* indices); |
230 | /* Visitor overrides for blocks, buffer stores and buffer loads; declared here only, all final. */ |
231 | Stmt VisitStmt_(const BlockNode* op) final; |
232 | Stmt VisitStmt_(const BufferStoreNode* op) final; |
233 | PrimExpr VisitExpr_(const BufferLoadNode* op) final; |
234 | }; |
235 | |
236 | } // namespace tir |
237 | } // namespace tvm |
238 | |
239 | #endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ |
240 | |