/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file manifest_shared_memory_local_stage.cc
 * \brief Add an explicit local stage for shared memory access on GPU.
 *
 * This pass finds cache_read stages on shared memory and creates an intermediate stage that
 * stores the data into local memory first, and then copies the data from local memory to shared
 * memory. This is similar to the schedule primitive cache_read, but it bypasses the limitation
 * that buffer accesses must be contiguous in each dimension.
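 *
 * As a simplified sketch (buffer names and loop structure here are illustrative only), an
 * annotated copy block of the form
 *
 *   for ax0, ax1:                       # relaxed serial / vectorized loops
 *     A_shared[ax0, ax1] = A[ax0, ax1]
 *
 * is rewritten into a local staging copy followed by the original copy, which now reads from the
 * new "local" buffer:
 *
 *   A_local = alloc_buffer((ext0, ext1), scope="local")
 *   for ax0_, ax1_:
 *     A_local[ax0_, ax1_] = A[ax0_, ax1_]
 *   for ax0, ax1:
 *     A_shared[ax0, ax1] = A_local[ax0, ax1]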
 */
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../runtime/thread_storage_scope.h"
#include "../schedule/transform.h"

namespace tvm {
namespace tir {

/*! \brief Rewriter for the block storing to the target buffer. Creates an intermediate cache
 * stage to store the result and rewrites the original block to load from the intermediate buffer.
 */
class IntermediateStageRewriter {
 public:
  explicit IntermediateStageRewriter(const std::vector<Stmt>& ancestor_loop_or_blocks)
      : ancestor_loop_or_blocks_(ancestor_loop_or_blocks) {}

  std::tuple<Buffer, Buffer, Block, Stmt> Rewrite(const BlockNode* block) {
    const BufferStoreNode* store = block->body.as<BufferStoreNode>();
    CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank ==
                                  runtime::StorageRank::kShared)
        << "ValueError: Expect the body of the block to be BufferStore to shared memory.";

    const Buffer& target_buffer = store->buffer;

    // Step 0: Collect relaxed loops
    std::vector<const ForNode*> relaxed_loops = CollectRelaxedOuterLoops(block, target_buffer);

    // Step 1: Create buffer for the local stage
    auto [new_buffer, buffer_indices] = CreateIntermediateBuffer(relaxed_loops, target_buffer);

    // Step 2: Create the local stage block
    Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store);

    // Step 3: Create BufferLoad from the intermediate buffer
    BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices);
    BufferStore new_buffer_store = Downcast<BufferStore>(block->body);
    new_buffer_store.CopyOnWrite()->value = new_buffer_load;
    Block new_block = GetRef<Block>(block);
    new_block.CopyOnWrite()->body = std::move(new_buffer_store);

    return {target_buffer, new_buffer, new_block, local_stage};
  }

 private:
  /*! \brief Collect relaxed outer loops from innermost to outermost */
  std::vector<const ForNode*> CollectRelaxedOuterLoops(const BlockNode* block,
                                                       const Buffer& target_buffer) {
    std::vector<const ForNode*> relaxed_loops;
    for (int n = static_cast<int>(ancestor_loop_or_blocks_.size()) - 1, i = n - 1; i >= 0; --i) {
      const Stmt& ancestor = ancestor_loop_or_blocks_[i];
      if (const ForNode* ancestor_loop = ancestor.as<ForNode>()) {
        CHECK(ancestor_loop->kind == ForKind::kSerial ||
              ancestor_loop->kind == ForKind::kVectorized)
            << "ValueError: Expect the ancestor loops to be serial or vectorized, got "
            << ancestor_loop->kind;
        relaxed_loops.push_back(ancestor.as<ForNode>());

        if (i < n - 1) {
          CHECK(ancestor_loop->body.same_as(ancestor_loop_or_blocks_[i + 1]))
              << "ValueError: Expect the ancestor loops to have a single child.";
        } else {
          const BlockRealizeNode* block_realize = ancestor_loop->body.as<BlockRealizeNode>();
          CHECK(block_realize != nullptr && block_realize->block.get() == block)
              << "ValueError: Expect the ancestor loops to have a single child.";
        }
      } else {
        const BlockRealizeNode* ancestor_block_realize = ancestor.as<BlockRealizeNode>();
        ICHECK(ancestor_block_realize != nullptr);
        const BlockNode* ancestor_block = ancestor_block_realize->block.get();
        auto it = std::find_if(
            ancestor_block->alloc_buffers.begin(), ancestor_block->alloc_buffers.end(),
            [&target_buffer](const Buffer& buffer) { return buffer.same_as(target_buffer); });
        CHECK(it != ancestor_block->alloc_buffers.end())
            << "ValueError: Expect the shared memory allocation to be in the parent block.";
        break;
      }
    }
    return relaxed_loops;
  }

  /*! \brief Create the intermediate stage. */
  Stmt MakeLocalStage(const BlockNode* block, const Buffer& new_buffer,
                      Array<PrimExpr> local_stage_indices,
                      const std::vector<const ForNode*>& relaxed_loops,
                      const BufferStoreNode* store) {
    // Step 0: Create the body of the local stage, which is BufferStore to the intermediate buffer.
    Stmt local_stage = BufferStore(new_buffer, store->value, local_stage_indices);

    // Step 1: Make block and block realize
    BufferRegion write_buffer_region = BufferRegion::FromPoint(new_buffer, local_stage_indices);
    local_stage =
        Block(/*iter_vars=*/{}, /*reads=*/block->reads, /*writes=*/{write_buffer_region}, "",
              /*body=*/std::move(local_stage));
    local_stage = BlockRealize(
        /*iter_values=*/{},
        /*predicate=*/ancestor_loop_or_blocks_.back().as<BlockRealizeNode>()->predicate,
        Downcast<Block>(local_stage));

    // Step 2: Add outer loops
    Map<Var, PrimExpr> subst_map;
    for (const ForNode* relaxed_loop : relaxed_loops) {
      ObjectPtr<ForNode> for_node = make_object<ForNode>(*relaxed_loop);
      for_node->loop_var = for_node->loop_var.copy_with_suffix("");
      for_node->body = std::move(local_stage);
      local_stage = For(for_node);
      subst_map.Set(relaxed_loop->loop_var, for_node->loop_var);
    }
    local_stage = Substitute(local_stage, subst_map);
    return local_stage;
  }

  /*! \brief Create the intermediate buffer with the extents of the relaxed outer loops. */
  std::pair<Buffer, Array<PrimExpr>> CreateIntermediateBuffer(
      const std::vector<const ForNode*>& relaxed_loops, const Buffer& buffer) const {
    Array<PrimExpr> buffer_indices;
    Array<PrimExpr> new_buffer_shape;

    // Create the intermediate buffer for the local stage. The shape of the new buffer is the
    // extents of the relaxed outer loops.
    for (auto it = relaxed_loops.rbegin(); it != relaxed_loops.rend(); ++it) {
      const ForNode* relaxed_loop = *it;
      buffer_indices.push_back(relaxed_loop->min + relaxed_loop->loop_var);
      new_buffer_shape.push_back(relaxed_loop->extent);
    }
    Buffer new_buffer = WithScope(buffer, "local");
    new_buffer.CopyOnWrite()->shape = new_buffer_shape;
    return {new_buffer, buffer_indices};
  }

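  /*! \brief The ancestor loops and block realizes of the annotated block, outermost first; the
   * last element is the BlockRealize of the block being rewritten. */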
  const std::vector<Stmt>& ancestor_loop_or_blocks_;
};

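/*! \brief Mutator that finds blocks annotated with attr::manifest_shared_memory_local_stage,
 * rewrites them via IntermediateStageRewriter, and hoists the resulting local-stage copy and
 * intermediate buffer allocation up to the block that allocates the target shared memory buffer.
 */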
class SharedMemoryLocalStageInserter : public StmtMutator {
 public:
  Stmt VisitStmt_(const ForNode* op) final {
    ancestor_loop_or_blocks_.push_back(GetRef<Stmt>(op));
    Stmt new_stmt = StmtMutator::VisitStmt_(op);
    ancestor_loop_or_blocks_.pop_back();
    return new_stmt;
  }

  Stmt VisitStmt_(const BlockRealizeNode* op) final {
    ancestor_loop_or_blocks_.push_back(GetRef<Stmt>(op));
    Stmt new_stmt = StmtMutator::VisitStmt_(op);
    ancestor_loop_or_blocks_.pop_back();
    return new_stmt;
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    if (op->annotations.count(attr::manifest_shared_memory_local_stage)) {
      // Rewrite the shared memory access to load from the intermediate buffer.
      // The annotated block must be a leaf block (will be checked during rewriting). No need to
      // visit its body recursively.
      IntermediateStageRewriter rewriter(ancestor_loop_or_blocks_);
      auto [target_buffer, new_buffer, new_block, local_stage] = rewriter.Rewrite(op);
      buffer_remap_.Set(target_buffer, new_buffer);

      new_block.CopyOnWrite()->annotations.erase(attr::manifest_shared_memory_local_stage);
      buffer_local_stage_.Set(target_buffer, local_stage);
      target_buffers_.push_back(target_buffer);

      return std::move(new_block);
    }

    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocated_buffers(
        op->alloc_buffers.begin(), op->alloc_buffers.end());

    // Visit children and insert local stages (if any) to the proper location.
    Array<Buffer> new_alloc_buffers;
    Array<Stmt> new_seq;

    // Helper function to check if the subtree (body of the block) contains any target buffers.
    // If so, the allocated intermediate buffer and the local stage should be lifted to the current
    // block.
    auto f_check_subtree = [&](int start, int end) {
      for (int i = start; i < end; ++i) {
        const Buffer& buffer = target_buffers_[i];
        if (allocated_buffers.count(buffer)) {
          new_seq.push_back(buffer_local_stage_.at(buffer));
          new_alloc_buffers.push_back(buffer_remap_.at(buffer));
        }
      }
    };

    if (const SeqStmtNode* seq = op->body.as<SeqStmtNode>()) {
      // Visit each element of the SeqStmt. Create a new SeqStmt if any of the children is
      // modified.
      bool changed = false;  // whether the SeqStmt has been changed
      for (int i = 0, n = static_cast<int>(seq->seq.size()); i < n; ++i) {
        int subtree_start = static_cast<int>(target_buffers_.size());
        Stmt new_seq_elem = VisitStmt(seq->seq[i]);
        int subtree_end = static_cast<int>(target_buffers_.size());
        f_check_subtree(subtree_start, subtree_end);
        new_seq.push_back(new_seq_elem);
        if (!new_seq_elem.same_as(seq->seq[i])) {
          changed = true;
        }
      }
      if (!changed) {
        return GetRef<Stmt>(op);
      }
    } else {
      int subtree_start = static_cast<int>(target_buffers_.size());
      Stmt body = VisitStmt(op->body);
      int subtree_end = static_cast<int>(target_buffers_.size());
      f_check_subtree(subtree_start, subtree_end);
      if (body.same_as(op->body)) {
        return GetRef<Stmt>(op);
      }
      new_seq.push_back(body);
    }

    Block new_block = GetRef<Block>(op);
    BlockNode* new_block_node = new_block.CopyOnWrite();
    // Add new buffer allocations if any.
    if (new_alloc_buffers.size() > 0) {
      new_block_node->alloc_buffers = Concat(new_block_node->alloc_buffers, new_alloc_buffers);
    }
    new_block_node->body = new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq);
    return std::move(new_block);
  }

  std::vector<Stmt> ancestor_loop_or_blocks_;  // ancestor loops or block realizes
  Map<Buffer, Buffer> buffer_remap_;  // mapping from the target buffer to the intermediate buffer
  Map<Buffer, Stmt> buffer_local_stage_;  // mapping from the target buffer to the local stage
  Array<Buffer> target_buffers_;  // the target buffers for rewriting
};

namespace transform {

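/*!
 * \brief Create the pass that adds an explicit local stage for shared memory copies.
 *
 * A minimal usage sketch (the module `mod` and the way its copy blocks receive the
 * attr::manifest_shared_memory_local_stage annotation are assumed, not shown here):
 *
 * \code
 *   IRModule rewritten = tir::transform::ManifestSharedMemoryLocalStage()(mod);
 * \endcode
 */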
Pass ManifestSharedMemoryLocalStage() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = SharedMemoryLocalStageInserter()(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {});
}

TVM_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage")
    .set_body_typed(ManifestSharedMemoryLocalStage);

}  // namespace transform
}  // namespace tir
}  // namespace tvm