1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * \file manifest_shared_memory_local_stage.cc
22 | * \brief Add the explicit local stage for the shared memory access on GPU. |
23 | * |
 * This pass finds the cache_read stage on the shared memory, and creates another intermediate stage
25 | * to store the data into local memory first, and then copy the data from local memory to the shared |
26 | * memory. This is similar to the schedule primitive cache_read, but it bypasses the limitation |
27 | * of requiring buffer access to be contiguous in each dimension. |
28 | */ |
29 | #include <tvm/arith/analyzer.h> |
30 | #include <tvm/tir/expr.h> |
31 | #include <tvm/tir/op.h> |
32 | #include <tvm/tir/stmt_functor.h> |
33 | #include <tvm/tir/transform.h> |
34 | |
35 | #include <unordered_set> |
36 | |
37 | #include "../../runtime/thread_storage_scope.h" |
38 | #include "../schedule/transform.h" |
39 | #include "tvm/tir/stmt.h" |
40 | |
41 | namespace tvm { |
42 | namespace tir { |
43 | |
44 | /*! \brief Rewriter for the block storing to the target buffer. Create an intermediate cache stage |
45 | * to store the result. Rewrite the original block to load from the intermediate buffer. |
46 | */ |
47 | class IntermediateStageRewriter { |
48 | public: |
49 | explicit IntermediateStageRewriter(const std::vector<Stmt>& ancestor_loop_or_blocks) |
50 | : ancestor_loop_or_blocks_(ancestor_loop_or_blocks) {} |
51 | |
52 | std::tuple<Buffer, Buffer, Block, Stmt> Rewrite(const BlockNode* block) { |
53 | const BufferStoreNode* store = block->body.as<BufferStoreNode>(); |
54 | CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank == |
55 | runtime::StorageRank::kShared) |
56 | << "ValueError: Expect the body of the block to be BufferStore to shared memory." ; |
57 | |
58 | const Buffer& target_buffer = store->buffer; |
59 | |
60 | // Step 0: Collect relaxed loops |
61 | std::vector<const ForNode*> relaxed_loops = CollectRelaxedOuterLoops(block, target_buffer); |
62 | |
63 | // Step 1: Create buffer for the local stage |
64 | auto [new_buffer, buffer_indices] = CreateIntermediateBuffer(relaxed_loops, target_buffer); |
65 | |
66 | // Step 2: Create the local stage block |
67 | Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); |
68 | |
69 | // Step 3: Create BufferLoad from the intermediate buffer |
70 | BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices); |
71 | BufferStore new_buffer_store = Downcast<BufferStore>(block->body); |
72 | new_buffer_store.CopyOnWrite()->value = new_buffer_load; |
73 | Block new_block = GetRef<Block>(block); |
74 | new_block.CopyOnWrite()->body = std::move(new_buffer_store); |
75 | |
76 | return {target_buffer, new_buffer, new_block, local_stage}; |
77 | } |
78 | |
79 | private: |
80 | /*! \brief Collect relaxed outer loops from innermost to outermost */ |
81 | std::vector<const ForNode*> CollectRelaxedOuterLoops(const BlockNode* block, |
82 | const Buffer& target_buffer) { |
83 | std::vector<const ForNode*> relaxed_loops; |
84 | for (int n = static_cast<int>(ancestor_loop_or_blocks_.size()) - 1, i = n - 1; i >= 0; --i) { |
85 | const Stmt& ancestor = ancestor_loop_or_blocks_[i]; |
86 | if (const ForNode* ancestor_loop = ancestor.as<ForNode>()) { |
87 | CHECK(ancestor_loop->kind == ForKind::kSerial || |
88 | ancestor_loop->kind == ForKind::kVectorized) |
89 | << "ValueError: Expect the ancestor loops to be serial or vectorized, got " |
90 | << ancestor_loop->kind; |
91 | relaxed_loops.push_back(ancestor.as<ForNode>()); |
92 | |
93 | if (i < n - 1) { |
94 | CHECK(ancestor_loop->body.same_as(ancestor_loop_or_blocks_[i + 1])) |
95 | << "ValueError: Expect the ancestor loops to have a single child." ; |
96 | } else { |
97 | const BlockRealizeNode* block_realize = ancestor_loop->body.as<BlockRealizeNode>(); |
98 | ICHECK(block_realize != nullptr); |
99 | CHECK(block_realize != nullptr && block_realize->block.get() == block) |
100 | << "ValueError: Expect the ancestor loops to have a single child." ; |
101 | } |
102 | } else { |
103 | const BlockRealizeNode* ancestor_block_realize = ancestor.as<BlockRealizeNode>(); |
104 | ICHECK(ancestor_block_realize != nullptr); |
105 | const BlockNode* ancestor_block = ancestor_block_realize->block.get(); |
106 | auto it = std::find_if( |
107 | ancestor_block->alloc_buffers.begin(), ancestor_block->alloc_buffers.end(), |
108 | [&target_buffer](const Buffer& buffer) { return buffer.same_as(target_buffer); }); |
109 | CHECK(it != ancestor_block->alloc_buffers.end()) |
110 | << "ValueError: Expect the shared memory allocation to be in the parent block." ; |
111 | break; |
112 | } |
113 | } |
114 | return relaxed_loops; |
115 | } |
116 | |
117 | /*! \brief Create the intermediate stage. */ |
118 | Stmt MakeLocalStage(const BlockNode* block, const Buffer& new_buffer, |
119 | Array<PrimExpr> local_stage_indices, |
120 | std::vector<const ForNode*> relaxed_loops, const BufferStoreNode* store) { |
121 | // Step 0: Create the body of the local stage, which is BufferStore to the intermediate buffer. |
122 | Stmt local_stage = BufferStore(new_buffer, store->value, local_stage_indices); |
123 | |
124 | // Step 1: Make block and block realize |
125 | BufferRegion write_buffer_region = BufferRegion::FromPoint(new_buffer, local_stage_indices); |
126 | local_stage = |
127 | Block(/*iter_vars=*/{}, /*reads=*/block->reads, /*writes=*/{write_buffer_region}, "" , |
128 | /*body=*/std::move(local_stage)); |
129 | local_stage = BlockRealize( |
130 | /*iter_values=*/{}, |
131 | /*predicate=*/ancestor_loop_or_blocks_.back().as<BlockRealizeNode>()->predicate, |
132 | Downcast<Block>(local_stage)); |
133 | |
134 | // Step 2: Add outer loops |
135 | Map<Var, PrimExpr> subst_map; |
136 | for (const ForNode* relaxed_loop : relaxed_loops) { |
137 | ObjectPtr<ForNode> for_node = make_object<ForNode>(*relaxed_loop); |
138 | for_node->loop_var = for_node->loop_var.copy_with_suffix("" ); |
139 | for_node->body = std::move(local_stage); |
140 | local_stage = For(for_node); |
141 | subst_map.Set(relaxed_loop->loop_var, for_node->loop_var); |
142 | } |
143 | local_stage = Substitute(local_stage, subst_map); |
144 | return local_stage; |
145 | } |
146 | |
147 | /*! \brief Create the intermediate buffer with the extents of the relaxed outer loops. */ |
148 | std::pair<Buffer, Array<PrimExpr>> CreateIntermediateBuffer( |
149 | const std::vector<const ForNode*> relaxed_loops, const Buffer& buffer) const { |
150 | Array<PrimExpr> buffer_indices; |
151 | Array<PrimExpr> new_buffer_shape; |
152 | |
153 | // Create the intermediate buffer for the local stage. The shape of the new buffer is the |
154 | // extents of the relaxed outer loops. |
155 | |
156 | for (auto it = relaxed_loops.rbegin(); it != relaxed_loops.rend(); ++it) { |
157 | const ForNode* relaxed_loop = *it; |
158 | buffer_indices.push_back(relaxed_loop->min + relaxed_loop->loop_var); |
159 | new_buffer_shape.push_back(relaxed_loop->extent); |
160 | } |
161 | Buffer new_buffer = WithScope(buffer, "local" ); |
162 | new_buffer.CopyOnWrite()->shape = new_buffer_shape; |
163 | return {new_buffer, buffer_indices}; |
164 | } |
165 | |
166 | const std::vector<Stmt>& ancestor_loop_or_blocks_; |
167 | }; |
168 | |
/*! \brief Mutator that drives IntermediateStageRewriter on each annotated block, then lifts the
 * generated local stages and intermediate buffer allocations up to the block that allocates the
 * corresponding shared-memory buffer.
 */
class SharedMemoryLocalStageInserter : public StmtMutator {
 public:
  Stmt VisitStmt_(const ForNode* op) final {
    // Maintain the ancestor chain so the rewriter can inspect the loops above the leaf block.
    ancestor_loop_or_blocks_.push_back(GetRef<Stmt>(op));
    Stmt new_stmt = StmtMutator::VisitStmt_(op);
    ancestor_loop_or_blocks_.pop_back();
    return new_stmt;
  }

  Stmt VisitStmt_(const BlockRealizeNode* op) final {
    // BlockRealize nodes are tracked too: the rewriter needs the realize of the annotated block
    // (for its predicate) and of the allocating block.
    ancestor_loop_or_blocks_.push_back(GetRef<Stmt>(op));
    Stmt new_stmt = StmtMutator::VisitStmt_(op);
    ancestor_loop_or_blocks_.pop_back();
    return new_stmt;
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    if (op->annotations.count(attr::manifest_shared_memory_local_stage)) {
      // Rewrite the shared memory access to load from the intermediate buffer.
      // The annotated block must be a leaf block (will be checked during rewriting). No need to
      // visit its body recursively.

      IntermediateStageRewriter rewriter(ancestor_loop_or_blocks_);
      auto [target_buffer, new_buffer, new_block, local_stage] = rewriter.Rewrite(op);
      buffer_remap_.Set(target_buffer, new_buffer);

      // The annotation has been honored; drop it from the rewritten block.
      new_block.CopyOnWrite()->annotations.erase(attr::manifest_shared_memory_local_stage);
      // Remember the local stage and buffer so the allocating ancestor block can insert them.
      buffer_local_stage_.Set(target_buffer, local_stage);
      target_buffers_.push_back(target_buffer);

      return std::move(new_block);
    }

    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocated_buffers(
        op->alloc_buffers.begin(), op->alloc_buffers.end());

    // Visit children and insert local stages (if any) to the proper location.
    Array<Buffer> new_alloc_buffers;
    Array<Stmt> new_seq;

    // Helper function to check if the subtree (body of the block) contains any target buffers.
    // If so, the allocated intermediate buffer and the local stage should be lifted to the current
    // block.
    auto f_check_subtree = [&](int start, int end) {
      // [start, end) is the slice of target_buffers_ discovered while visiting one subtree.
      for (int i = start; i < end; ++i) {
        const Buffer& buffer = target_buffers_[i];
        if (allocated_buffers.count(buffer)) {
          // This block allocates the shared buffer: emit the local stage before the subtree and
          // take ownership of the intermediate buffer allocation.
          new_seq.push_back(buffer_local_stage_.at(buffer));
          new_alloc_buffers.push_back(buffer_remap_.at(buffer));
        }
      }
    };

    if (const SeqStmtNode* seq = op->body.as<SeqStmtNode>()) {
      // Visit each element of the SeqStmt. Create a new SeqStmt if any of the children is modified.
      bool changed = false;  // whether the SeqStmt has been changed
      for (int i = 0, n = seq->seq.size(); i < n; ++i) {
        // Record which target buffers were discovered inside this child so their local stages
        // can be placed immediately before it.
        int subtree_start = target_buffers_.size();
        Stmt new_seq_elem = VisitStmt(seq->seq[i]);
        int subtree_end = target_buffers_.size();
        f_check_subtree(subtree_start, subtree_end);
        new_seq.push_back(new_seq_elem);
        if (!new_seq_elem.same_as(seq->seq[i])) {
          changed = true;
        }
      }
      if (!changed) {
        return GetRef<Stmt>(op);
      }
    } else {
      int subtree_start = target_buffers_.size();
      Stmt body = VisitStmt(op->body);
      int subtree_end = target_buffers_.size();
      f_check_subtree(subtree_start, subtree_end);
      if (body.same_as(op->body)) {
        return GetRef<Stmt>(op);
      }
      new_seq.push_back(body);
    }

    Block new_block = GetRef<Block>(op);
    BlockNode* new_block_node = new_block.CopyOnWrite();
    // Add new buffer allocations if any.
    if (new_alloc_buffers.size() > 0) {
      new_block_node->alloc_buffers = Concat(new_block_node->alloc_buffers, new_alloc_buffers);
    }
    // Avoid wrapping a single statement in a SeqStmt.
    new_block_node->body = new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq);
    return std::move(new_block);
  }

  std::vector<Stmt> ancestor_loop_or_blocks_;  // ancestor loops or block realize
  Map<Buffer, Buffer> buffer_remap_;  // mapping from the target buffer to the intermediate buffer
  Map<Buffer, Stmt> buffer_local_stage_;  // mapping from the target buffer to the local stage
  Array<Buffer> target_buffers_;  // the target buffers for rewriting, in discovery order
};
264 | |
265 | namespace transform { |
266 | |
267 | Pass ManifestSharedMemoryLocalStage() { |
268 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
269 | auto* n = f.CopyOnWrite(); |
270 | n->body = SharedMemoryLocalStageInserter()(std::move(n->body)); |
271 | return f; |
272 | }; |
273 | return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage" , {}); |
274 | } |
275 | |
// Expose the pass constructor to the FFI so it is reachable from the Python frontend as
// tvm.tir.transform.ManifestSharedMemoryLocalStage.
TVM_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage")
    .set_body_typed(ManifestSharedMemoryLocalStage);
278 | |
279 | } // namespace transform |
280 | } // namespace tir |
281 | } // namespace tvm |
282 | |