1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file inject_software_pipeline.cc |
 * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers
23 | */ |
24 | #include <tvm/target/target.h> |
25 | #include <tvm/tir/builtin.h> |
26 | #include <tvm/tir/transform.h> |
27 | |
28 | #include <unordered_set> |
29 | |
30 | #include "../../support/utils.h" |
31 | #include "../schedule/utils.h" |
32 | #include "./ir_utils.h" |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | namespace software_pipeline { |
38 | |
39 | /*! |
40 | * \brief Create a block and infer the access region with the given body. |
41 | * |
 * The result is an opaque block that doesn't contain any block iter vars. If the body is a
 * block realize without a predicate, it is unnecessary to create a new block; the block of the
 * block realize is returned directly.
45 | * |
46 | * \param body The body of the block. |
47 | * \param buffer_data_to_buffer The map from buffer data to buffer. |
48 | * \return The result block. |
49 | */ |
50 | Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) { |
51 | if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) { |
52 | if (is_one(block_realize->predicate)) { |
53 | // no need to create a new block |
54 | return block_realize->block; |
55 | } |
56 | } |
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body=*/body);
58 | Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer); |
59 | BlockNode* n = block.CopyOnWrite(); |
60 | n->reads = access[0]; |
61 | n->writes = access[1]; |
62 | return block; |
63 | } |
64 | |
65 | /*! Structure that represents the provided annotation per block or loop. */ |
66 | struct PipelineAnnotation { |
67 | int stage; |
68 | int order; |
69 | bool async; |
70 | }; |
71 | |
72 | using PipelineInfo = std::unordered_map<Block, PipelineAnnotation, ObjectPtrHash, ObjectPtrEqual>; |
73 | |
74 | struct BufferAccessInfo { |
75 | int def = -1; // the defining stage of the buffer |
76 | int use = -1; // the last using stage of the buffer |
77 | }; |
78 | |
79 | class PipelineOpaqueAccessRewriter { |
80 | public: |
81 | /*! |
82 | * \brief Constructor |
83 | * \param buffer_data_to_buffer The map from buffer data to buffer. |
84 | * \param buffer_remap The map from original buffer to the buffer with updated shape for |
85 | * multi-versioning in the software pipeline. |
86 | * \param pipeline_loop The original loop to be software pipelined. |
87 | * \param fragment_info Information about tensor core fragment |
88 | */ |
89 | PipelineOpaqueAccessRewriter( |
90 | const Map<Var, Buffer>& buffer_data_to_buffer, const Map<Buffer, Buffer>& buffer_remap, |
91 | const For& pipeline_loop, |
92 | const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) |
93 | : buffer_data_to_buffer_(buffer_data_to_buffer), |
94 | buffer_remap_(buffer_remap), |
95 | pipeline_loop_(pipeline_loop), |
96 | fragment_info_(fragment_info) {} |
97 | |
98 | PrimExpr Rewrite(const Call& call) { |
    // Intrinsic calls should be handled explicitly here as they are opaque accesses to buffers.
101 | static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); |
102 | static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); |
103 | static const auto& mma_sync = builtin::tvm_mma_sync(); |
104 | static const auto& access_ptr = builtin::tvm_access_ptr(); |
105 | static const auto& ptx_ldmatrix = builtin::ptx_ldmatrix(); |
106 | static const auto& ptx_mma = builtin::ptx_mma(); |
107 | if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { |
108 | const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[0])); |
109 | auto it = buffer_remap_.find(buffer); |
110 | if (it != buffer_remap_.end()) { |
111 | Array<PrimExpr> new_args = call->args; |
112 | const Buffer& new_buffer = (*it).second; |
113 | new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4])); |
114 | return Call(call->dtype, call->op, new_args, call->span); |
115 | } |
116 | } else if (call->op.same_as(mma_sync)) { |
117 | Array<PrimExpr> new_args = call->args; |
118 | for (int i = 0; i < 4; i++) { |
119 | const Var& buffer_var = Downcast<Var>(call->args[i * 2]); |
120 | const PrimExpr& index = call->args[i * 2 + 1]; |
121 | const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var); |
122 | auto it = buffer_remap_.find(buffer); |
123 | if (it != buffer_remap_.end()) { |
124 | PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index); |
125 | new_args.Set(i * 2 + 1, new_index); |
126 | } |
127 | } |
128 | return Call(call->dtype, call->op, new_args, call->span); |
129 | } else if (call->op.same_as(access_ptr)) { |
130 | return RewriteBufferAccess(call, {1}); |
131 | } else if (call->op.same_as(ptx_mma)) { |
132 | return RewriteBufferAccess(call, {6, 8, 10}); |
133 | } else if (call->op.same_as(ptx_ldmatrix)) { |
134 | return RewriteBufferAccess(call, {3}); |
135 | } |
136 | return call; |
137 | } |
138 | |
139 | private: |
140 | int GetWmmaFragmentSize(const Buffer& buffer) { |
141 | auto it = fragment_info_.find(buffer->data.get()); |
142 | ICHECK(it != fragment_info_.end()); |
143 | const FragmentInfo& info = (*it).second; |
144 | return info.GetSize(); |
145 | } |
146 | |
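  /*!
   * \brief Rewrite the index of a wmma fragment buffer access for multi-versioning.
   *
   * A schematic example (with assumed shapes, not taken from a real workload): for an original
   * fragment buffer of shape [16, 16] and fragment size 256, one version holds
   * floordiv(16 * 16, 256) = 1 fragment, so the rewritten index is
   * old_index + floormod(i - min, num_versions) * 1, where i is the pipeline loop var.
   */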
147 | PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer, |
148 | const PrimExpr& old_index) { |
149 | PrimExpr new_buffer_offset = old_index; |
150 | |
151 | int fragment_size = GetWmmaFragmentSize(old_buffer); |
152 | PrimExpr offset = |
153 | floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, |
154 | make_const(DataType::Int(32), 1), old_buffer->shape), |
155 | fragment_size); |
156 | new_buffer_offset += |
157 | floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset; |
158 | return new_buffer_offset; |
159 | } |
160 | |
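  /*!
   * \brief Rewrite an opaque buffer access; for each i in arg_indices, call->args[i] is a buffer
   * var and call->args[i + 1] is the corresponding element offset.
   *
   * A schematic example (with assumed shapes, not taken from a real workload): for
   * `tvm_access_ptr(ty, A.data, offset, extent, rw)` on a buffer A of shape [64] remapped to
   * 2 versions, the offset is rewritten to offset + floormod(i, 2) * 64, where i is the
   * pipeline loop var.
   */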
  PrimExpr RewriteBufferAccess(const Call& call, const std::vector<int>& arg_indices) {
162 | auto product = [](const Array<PrimExpr>& input) { |
163 | return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, |
164 | make_const(DataType::Int(32), 1), input); |
165 | }; |
166 | Array<PrimExpr> new_args = call->args; |
167 | for (int i : arg_indices) { |
168 | const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[i])); |
169 | auto it = buffer_remap_.find(buffer); |
170 | if (it != buffer_remap_.end()) { |
171 | const Buffer& new_buffer = (*it).second; |
172 | const PrimExpr& old_index = call->args[i + 1]; |
173 | PrimExpr offset; |
174 | if (new_buffer->strides.empty()) { |
175 | offset = product(buffer->shape); |
176 | } else { |
177 | offset = new_buffer->strides[0]; |
178 | } |
179 | PrimExpr new_index = |
180 | old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset; |
181 | new_args.Set(i + 1, new_index); |
182 | } |
183 | } |
184 | return Call(call->dtype, call->op, new_args, call->span); |
185 | } |
186 | |
187 | const Map<Var, Buffer>& buffer_data_to_buffer_; |
188 | const Map<Buffer, Buffer>& buffer_remap_; |
189 | const For& pipeline_loop_; |
190 | const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_; |
191 | }; |
192 | |
193 | /*! |
 * \brief Rewriter for the body of the software pipeline. This rewriter inserts `floormod` into
 * the indices of remapped buffers to select the version corresponding to the pipeline stage.
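 *
 * A schematic example: with two versions, a store `A[j] = ...` inside the pipeline loop over `i`
 * becomes `A[floormod(i - min, 2), j] = ...`, and loads of `A` are rewritten likewise.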
196 | */ |
197 | class PipelineBodyRewriter : public StmtExprMutator { |
198 | public: |
199 | /*! |
200 | * \brief Constructor of PipelineBodyRewriter. |
201 | * \param buffer_data_to_buffer The map from buffer data to buffer. |
202 | * \param buffer_remap The map from original buffer to the buffer with updated shape for |
203 | * multi-versioning in the software pipeline. |
204 | * \param pipeline_loop The original loop to be software pipelined. |
   * \param access_all_versions Whether all versions of the buffers in the software pipeline are
   * accessed. This will be used to update the block access regions. In the prologue and epilogue
   * of a two-stage software pipeline, only one version of these buffers is accessed.
208 | * \param fragment_info Information about tensor core fragment |
209 | */ |
210 | PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer, |
211 | const Map<Buffer, Buffer>& buffer_remap, For pipeline_loop, |
212 | bool access_all_versions, |
213 | const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) |
214 | : buffer_data_to_buffer_(buffer_data_to_buffer), |
215 | buffer_remap_(buffer_remap), |
216 | pipeline_loop_(pipeline_loop), |
217 | access_all_versions_(access_all_versions), |
218 | opaque_access_rewriter_(buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, |
219 | fragment_info) {} |
220 | |
221 | private: |
222 | BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const { |
223 | auto it = buffer_remap_.find(buffer_region->buffer); |
224 | if (it != buffer_remap_.end()) { |
225 | Region new_region = buffer_region->region; |
226 | const Buffer& new_buffer = (*it).second; |
227 | // For pipeline buffers, relax the access region of the first dimension to full extent |
228 | // if access_all_versions == true |
229 | Range accessed_version = |
230 | access_all_versions_ |
231 | ? Range::FromMinExtent(0, new_buffer->shape[0]) |
232 | : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min), |
233 | new_buffer->shape[0]), |
234 | Integer(1)); |
235 | new_region.insert(new_region.begin(), accessed_version); |
236 | return BufferRegion(new_buffer, new_region); |
237 | } |
238 | return buffer_region; |
239 | } |
240 | |
241 | Stmt VisitStmt_(const BlockNode* op) final { |
242 | for (const Buffer& alloc_buffer : op->alloc_buffers) { |
243 | buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); |
244 | } |
245 | Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); |
246 | BlockNode* n = block.CopyOnWrite(); |
247 | n->reads.MutateByApply([this](const BufferRegion& buffer_region) { |
248 | return RewritePipelineBufferRegion(buffer_region); |
249 | }); |
250 | n->writes.MutateByApply([this](const BufferRegion& buffer_region) { |
251 | return RewritePipelineBufferRegion(buffer_region); |
252 | }); |
253 | for (const Buffer& alloc_buffer : op->alloc_buffers) { |
254 | buffer_data_to_buffer_.erase(alloc_buffer->data); |
255 | } |
256 | return std::move(block); |
257 | } |
258 | |
259 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
260 | BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
261 | auto it = buffer_remap_.find(store->buffer); |
262 | if (it == buffer_remap_.end()) { |
263 | return std::move(store); |
264 | } |
265 | const Buffer& new_buffer = (*it).second; |
266 | auto* n = store.CopyOnWrite(); |
267 | n->buffer = new_buffer; |
268 | PrimExpr version = |
269 | floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); |
270 | n->indices.insert(n->indices.begin(), version); |
271 | return std::move(store); |
272 | } |
273 | |
274 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
275 | BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
276 | auto it = buffer_remap_.find(load->buffer); |
277 | if (it == buffer_remap_.end()) { |
278 | return std::move(load); |
279 | } |
280 | const Buffer& new_buffer = (*it).second; |
281 | auto* n = load.CopyOnWrite(); |
282 | n->buffer = new_buffer; |
283 | PrimExpr version = |
284 | floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]); |
285 | n->indices.insert(n->indices.begin(), version); |
286 | return std::move(load); |
287 | } |
288 | |
289 | PrimExpr VisitExpr_(const CallNode* op) final { |
290 | Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op)); |
291 | return opaque_access_rewriter_.Rewrite(call); |
292 | } |
293 | |
294 | Map<Var, Buffer> buffer_data_to_buffer_; |
295 | Map<Buffer, Buffer> buffer_remap_; |
296 | For pipeline_loop_; |
297 | bool access_all_versions_; |
298 | PipelineOpaqueAccessRewriter opaque_access_rewriter_; |
299 | }; |
300 | |
301 | /*! |
 * \brief Rewriter for the software pipeline that rewrites a loop into a pipelined one.
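 *
 * A schematic example, assuming a two-stage pipeline (stage = [0, 1], order = [0, 1]) over
 * `for i in range(16)` with a producer `A_shared[...] = A[i]` and a consumer
 * `compute(A_shared[...])`: the result consists of a prologue that runs only the producer for
 * iteration 0, a steady-state body in which the producer for iteration i + 1 overlaps the
 * consumer for iteration i (selecting between the two versions of A_shared), and an epilogue
 * that runs only the consumer for the last iteration.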
303 | */ |
304 | class PipelineRewriter : public StmtExprMutator { |
305 | public: |
306 | static Stmt Rewrite( |
307 | Map<Var, Buffer> buffer_data_to_buffer, |
308 | const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers, |
309 | const Array<Buffer> pipeline_allocs, const For& pipeline_loop, |
310 | const PipelineInfo& pipeline_info, |
311 | const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info, |
312 | const Map<String, ObjectRef> preserved_annotations, bool merge_async_commit_queue_scope) { |
313 | PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, |
314 | pipeline_info, fragment_info, preserved_annotations, |
315 | merge_async_commit_queue_scope); |
316 | return rewriter.BuildPipeline(); |
317 | } |
318 | |
319 | private: |
320 | PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer, |
321 | const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers, |
322 | const Array<Buffer>& pipeline_allocs, const For& pipeline_loop, |
323 | const PipelineInfo& pipeline_info, |
324 | const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info, |
325 | const Map<String, ObjectRef> preserved_annotations, |
326 | bool merge_async_commit_queue_scope) |
327 | |
328 | : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), |
329 | double_buffers_(double_buffers), |
330 | pipeline_allocs_(pipeline_allocs), |
331 | pipeline_loop_(pipeline_loop), |
332 | pipeline_info_(pipeline_info), |
333 | fragment_info_(fragment_info), |
334 | preserved_annotations_(preserved_annotations), |
335 | merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {} |
336 | |
337 | Stmt BuildPipeline() { |
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
    // that need to be maintained for each buffer.
340 | std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos = |
341 | GetBufferAccessInfo(); |
342 | for (const Buffer& buffer : pipeline_allocs_) { |
343 | int num_versions = ComputeBufferVersions(buffer, infos.at(buffer)); |
344 | if (num_versions > 1) { |
345 | buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions)); |
346 | } |
347 | } |
348 | |
349 | ordered_stmts_.resize(pipeline_info_.size()); |
350 | for (const auto& pair : pipeline_info_) { |
351 | const Block& block = pair.first; |
352 | int order = pair.second.order; |
353 | ordered_stmts_.Set(order, block); |
354 | } |
355 | |
356 | // Step 2: Emit the pipeline prologue, body and epilogue. |
357 | Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true); |
358 | Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, |
359 | pipeline_loop_->min + pipeline_loop_->extent, false); |
360 | Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent, |
361 | pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true); |
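
    // For example (schematic): with min = 0, extent = 16 and max_stage_ = 1, the prologue covers
    // [0, 1), the body covers [1, 16), and the epilogue covers [16, 17).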
362 | |
363 | SeqStmt stmt = SeqStmt({prologue, body, epilogue}); |
364 | |
365 | // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting. |
366 | Array<Buffer> alloc_buffers; |
367 | for (const auto& alloc : pipeline_allocs_) { |
368 | alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc)); |
369 | buffer_data_to_buffer_.erase(alloc->data); |
370 | } |
371 | Block block = MakeBlock(stmt, buffer_data_to_buffer_); |
372 | block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers); |
373 | return BlockRealize({}, Bool(true), block); |
374 | } |
375 | |
376 | private: |
377 | /*! |
378 | * \brief Analyze accesses to the buffers in the software pipeline. |
379 | * |
   * This method determines the 'define' and 'use' stages of each buffer in the software pipeline,
   * which are used to compute the number of versions that need to be maintained after rewriting.
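   *
   * For example (schematic): a buffer written by a block at stage 0 and read by blocks at
   * stages 1 and 2 gets {def = 0, use = 2}.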
382 | */ |
383 | std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> |
384 | GetBufferAccessInfo() { |
385 | std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos; |
386 | for (const auto& pair : pipeline_info_) { |
387 | const Block& block = pair.first; |
388 | int stage = pair.second.stage; |
389 | max_stage_ = std::max(max_stage_, stage); |
390 | |
391 | for (const BufferRegion& write : block->writes) { |
392 | if (!infos.count(write->buffer)) { |
393 | infos.emplace(write->buffer, BufferAccessInfo{}); |
394 | } |
395 | auto& info = infos.at(write->buffer); |
396 | if (info.def == -1) { |
397 | info.def = stage; |
398 | } else { |
399 | info.def = std::min(info.def, stage); |
400 | } |
401 | } |
402 | |
403 | for (const BufferRegion& read : block->reads) { |
404 | if (!infos.count(read->buffer)) { |
405 | infos.emplace(read->buffer, BufferAccessInfo{}); |
406 | } |
407 | auto& info = infos.at(read->buffer); |
408 | info.use = std::max(info.use, stage); |
409 | } |
410 | } |
411 | return infos; |
412 | } |
413 | |
414 | /*! |
415 | * \brief Check whether two regions have intersections. |
416 | * \param region1 The first region. |
417 | * \param region2 The second region. |
418 | * \return Whether region1 and region2 have intersections. |
419 | */ |
420 | bool MayConflict(Region region1, Region region2) { |
421 | ICHECK(region1.size() == region2.size()); |
422 | for (size_t i = 0; i < region1.size(); i++) { |
423 | Range dim1 = region1[i]; |
424 | Range dim2 = region2[i]; |
425 | auto int_set1 = arith::IntSet::FromRange(dim1); |
426 | auto int_set2 = arith::IntSet::FromRange(dim2); |
427 | if (arith::Intersect({int_set1, int_set2}).IsNothing()) { |
428 | return false; |
429 | } |
430 | } |
431 | return true; |
432 | } |
433 | |
434 | /*! |
   * \brief Compute the number of versions that need to be maintained for a buffer accessed in the
   * software pipeline.
   *
   * This method applies liveness analysis to the target buffer to compute the number of versions
   * that need to be maintained during the software pipeline execution.
   * The annotation `attr::double_buffer_scope` is handled here, which provides a way to override
   * the result of the analysis. Additional double buffering in the software pipeline can be useful
   * to eliminate synchronizations on GPU devices.
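   *
   * A schematic example: a buffer with {def = 0, use = 2} needs `use - def + 1 = 3` versions,
   * because when the consumer at stage 2 reads the data of iteration i, the producer at stage 0
   * has already written the data of iterations i + 1 and i + 2.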
443 | * |
444 | * \param buffer The target buffer |
445 | * \param buffer_info The access information of the target buffer. |
446 | * \return The number of versions required for the target buffer. |
447 | */ |
448 | int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) { |
449 | if (buffer_info.def == -1) { |
450 | // Keep the original number of versions as buffers defined outside the software pipeline |
451 | // should not be mutated. |
452 | return 1; |
453 | } |
454 | |
    // `use - def + 1` is an upper bound of the number of versions needed.
    // We optimize a few cases where the number of versions can be smaller than the upper bound.
457 | int num_versions = buffer_info.use - buffer_info.def + 1; |
458 | if (num_versions == 2) { |
      // A special case when `use - def + 1 == 2`. Double buffering is only needed in this case
      // when there exists a writer block_i and a reader block_j such that
      // order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access regions
      // of block_i and block_j overlap.
463 | bool need_multi_version = false; |
464 | for (const auto& pair1 : pipeline_info_) { |
465 | const Block& writer_block = pair1.first; |
466 | const auto& writer_info = pair1.second; |
467 | |
468 | auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(), |
469 | [&](const BufferRegion& buffer_region) { |
470 | return buffer_region->buffer.same_as(buffer); |
471 | }); |
472 | if (it1 == writer_block->writes.end()) { |
473 | continue; |
474 | } |
475 | |
476 | for (const auto& pair2 : pipeline_info_) { |
477 | const Block& reader_block = pair2.first; |
478 | const auto& reader_info = pair2.second; |
479 | auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(), |
480 | [&](const BufferRegion& buffer_region) { |
481 | return buffer_region->buffer.same_as(buffer); |
482 | }); |
483 | if (it2 == reader_block->reads.end()) { |
484 | continue; |
485 | } |
486 | if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage && |
487 | MayConflict((*it1)->region, (*it2)->region)) { |
488 | need_multi_version = true; |
489 | break; |
490 | } |
491 | } |
492 | } |
493 | if (!need_multi_version) { |
494 | num_versions = 1; |
495 | } |
496 | } |
497 | if (num_versions == 1 && double_buffers_.count(buffer)) { |
498 | num_versions = 2; |
499 | } |
500 | return num_versions; |
501 | } |
502 | |
503 | /*! |
   * \brief Rewrite the buffer allocation to keep multiple versions of the original buffer for
   * pipelined accesses.
506 | * \param buffer The buffer to be resized. |
507 | * \param num_versions The number of versions to keep. |
508 | * \return The resized buffer. |
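   *
   * A schematic example (with assumed shapes): a buffer of shape [64, 64] with strides [64, 1]
   * resized to 2 versions becomes shape [2, 64, 64] with strides [64 * 64, 64, 1]; the new
   * leading dimension selects the version.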
509 | */ |
510 | Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) { |
511 | ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get())); |
512 | new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); |
513 | if (new_buffer->strides.size()) { |
514 | ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); |
515 | PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1]; |
516 | new_buffer->strides.insert(new_buffer->strides.begin(), stride_0); |
517 | } |
518 | return Buffer(new_buffer); |
519 | } |
520 | |
521 | // Per-stage states that need to be tracked across pipeline prologue, body, and epilogue. |
522 | struct AsyncStateGlobal { |
523 | // Buffers that this stage asynchronously writes. |
524 | std::unordered_set<const BufferNode*> dst_buffers; |
525 | // An imaginary index that the latest async operation associated with this stage has written |
526 | // into. Only valid if all associated predicates are true, so that we can count the number of |
527 | // async invocations exactly. When it is valid, it is the "sum of extents of loops that have |
528 | // been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. This |
529 | // is only needed to compute wait count for epilogue without async producers. |
530 | Optional<PrimExpr> producer_head{PrimExpr(-1)}; |
531 | |
532 | bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; } |
533 | }; |
534 | |
535 | // Per-stage states that are local to each of pipeline prologue, body, and epilogue. |
536 | struct AsyncStateLocal { |
537 | struct { |
538 | // The index into a list of blocks, where async_wait_queue should be attached at the |
539 | // beginning. |
540 | int insert_before; |
541 | // in_flight_count would be a more precise name, but the implementation uses wait_count for |
542 | // brevity. |
543 | PrimExpr wait_count{nullptr}; |
544 | |
545 | bool valid() const { return wait_count.defined(); } |
546 | } pending_wait; |
547 | |
548 | // Destination buffers of async operations that have been encountered so far in the loop |
549 | // |
550 | // for (size_t i = 0; i < new_blocks.size(); ++i) { |
551 | // ... |
552 | // } |
553 | // |
554 | // This is for tracking which async operations have been issued at the "current" iteration, up |
555 | // until a point where we encounter a consumer of async result buffers. This is used to decide |
556 | // if the producer_head of each buffer points to a copy written in the current or previous |
557 | // iteration. |
558 | std::unordered_set<const BufferNode*> seen; |
559 | |
560 | // A symbolic expression representing the index the latest async operation associated with this |
561 | // stage has written into, at the "current" iteration. |
562 | Optional<PrimExpr> producer_head; |
563 | // The predicate of BlockRealize containing the async operation of this stage. |
564 | Optional<PrimExpr> predicate; |
565 | // Indices into a list of blocks, where async_commit_queue scope should be attached. |
    // If multiple async producers are interleaved with their consumer in between, we need a
    // separate async_commit_queue for each producer. Thus, we need multiple sets of indices.
568 | std::vector<std::vector<size_t>> commit_groups; |
569 | |
570 | // This is set to true when we reach a stage that consumes this async stage. |
571 | bool consumed{false}; |
572 | }; |
573 | |
574 | /*! Structure holding intermediate information for pipeline loop rewriting. */ |
575 | struct RewrittenBlockInfo { |
576 | int stage; |
577 | PrimExpr predicate; |
578 | Block block; |
579 | PrimExpr access_index; |
580 | bool is_async; |
581 | }; |
582 | |
583 | // Determine where to insert async_wait and the corresponding wait count. |
584 | void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks, |
585 | arith::Analyzer* ana_normalized, |
586 | const std::unordered_map<const BufferNode*, int>& buffer_to_commit_group, |
587 | std::map<int, AsyncStateLocal>* async_states_local) { |
588 | for (size_t i = 0; i < new_blocks.size(); ++i) { |
589 | if (new_blocks[i].is_async) { |
590 | // Record the fact that we have encountered these write buffers. |
591 | for (auto write_region : new_blocks[i].block->writes) { |
592 | (*async_states_local)[new_blocks[i].stage].seen.insert(write_region->buffer.get()); |
593 | } |
594 | } |
595 | |
596 | int producer_stage_idx = -1; |
597 | for (auto read_region : new_blocks[i].block->reads) { |
598 | for (auto kv : async_states) { |
599 | if (kv.first <= new_blocks[i].stage && kv.second.writes(read_region->buffer)) { |
600 | // Found an earlier stage where read_region->buffer was asynchronously written |
601 | ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first) |
602 | << "A dependency on multiple async stages is not supported" ; |
603 | producer_stage_idx = kv.first; |
604 | } |
605 | } |
606 | } |
607 | |
608 | if (producer_stage_idx == -1) continue; |
609 | |
      // The following logic has become complicated to handle cases like this:
611 | // |
612 | // for i in range(13): |
613 | // # Stage 0 |
614 | // async_commit_queue(0): |
615 | // async_scope: |
616 | // A_shared[(i + 3) % 4] = A[...] |
617 | // |
618 | // |
619 | // # Stage 1 |
620 | // async_wait_queue(0, 5): |
621 | // compute(A_shared[i], B_shared[i]) |
622 | // |
623 | // # Stage 0 |
624 | // async_commit_queue(0) |
625 | // async_scope: |
626 | // B_shared[(i + 3) % 4] = B[...] |
627 | // |
628 | // |
629 | // Here, multiple async producers in the same stage are interleaved with their consumer in |
      // between. Since each buffer is associated with a different commit group, the wait_count
      // before the consumer should be bigger than in the simpler case:
632 | // |
633 | // for i in range(13): |
634 | // # Stage 0 |
635 | // async_commit_queue(0): |
636 | // async_scope: |
637 | // A_shared[(i + 3) % 4] = A[...] |
638 | // B_shared[(i + 3) % 4] = B[...] |
639 | // |
640 | // # Stage 1 |
641 | // async_wait_queue(0, 3): |
642 | // compute(A_shared[i], B_shared[i]) |
643 | // |
644 | // The correct wait_count can be determined by considering each commit group separately, and |
645 | // summing "per-commit" wait_counts. |
646 | // |
647 | // From A_shared's perspective, it allows for (i + 3) - i async commit groups to be in |
648 | // flight while from B_shared's perspective, the producer head at compute points to the copy |
649 | // done by the previous iteration, so its wait_count is calculated as ((i - 1) + 3) - i. The |
650 | // sum of the two wait_counts gives 5. |
651 | |
652 | auto& dep_local_state = (*async_states_local)[producer_stage_idx]; |
653 | const auto num_commit_group = dep_local_state.commit_groups.size(); |
654 | std::vector<Optional<PrimExpr>> producer_head_per_commit; |
655 | |
656 | if (num_commit_group == 0) { |
657 | // Epilogue, no async producer. Since "local" producer_head is not available, use |
658 | // "global" producer_head. |
659 | ICHECK(!dep_local_state.producer_head); |
660 | producer_head_per_commit.push_back(async_states[producer_stage_idx].producer_head); |
661 | } else { |
662 | ICHECK(dep_local_state.producer_head); |
663 | std::vector<bool> need_wait_count(num_commit_group, true); |
664 | |
665 | for (auto read_region : new_blocks[i].block->reads) { |
666 | if (!async_states[producer_stage_idx].writes(read_region->buffer)) continue; |
667 | auto commit_group_id = buffer_to_commit_group.at(read_region->buffer.get()); |
668 | if (!need_wait_count[commit_group_id]) continue; |
669 | |
670 | if (!dep_local_state.seen.count(read_region->buffer.get())) { |
671 | // Multiple async producers interleaved: The most recent async write is from the |
672 | // previous iteration. This is the B_shared case above. |
673 | producer_head_per_commit.push_back(dep_local_state.producer_head.value() - 1); |
674 | } else { |
675 | // Normal case |
676 | producer_head_per_commit.push_back(dep_local_state.producer_head.value()); |
677 | } |
678 | |
679 | need_wait_count[commit_group_id] = false; |
680 | } |
681 | } |
682 | |
683 | auto wait_count = [=, &ana_normalized]() { |
684 | auto sum = PrimExpr(0); |
685 | for (auto producer_head : producer_head_per_commit) { |
686 | if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) { |
687 | // Here, new_blocks[i].access_index corresponds to "consumer_head". |
688 | // The difference of producer_head and consumer_head is precisely the number of |
689 | // async commit groups that can still be in flight after this wait. |
690 | sum += analyzer_.Simplify(producer_head.value() - new_blocks[i].access_index); |
691 | } else { |
692 | // The precise count cannot be determined, give up. |
693 | return PrimExpr(0); |
694 | } |
695 | } |
696 | return sum; |
697 | }(); |
698 | |
699 | auto& pending_wait = dep_local_state.pending_wait; |
700 | |
701 | if (!pending_wait.valid()) { |
702 | pending_wait = {static_cast<int>(i), wait_count}; |
703 | } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) { |
704 | // Coalesce multiple wait_queue if the later one allows fewer in-flight ops. |
705 | pending_wait = {pending_wait.insert_before, wait_count}; |
706 | } |
707 | } |
708 | } |
709 | |
710 | // Given pipelined blocks and async-related information, generate final loop statements with async |
711 | // scopes (if any). |
712 | Array<Stmt> CompletePipelineLoopStatements( |
713 | const std::vector<RewrittenBlockInfo>& blocks, |
714 | const std::map<int, AsyncStateLocal>& async_states_local, |
715 | arith::Analyzer* ana_normalized) const { |
716 | std::vector<RewrittenBlockInfo> new_blocks = blocks; |
717 | std::vector<int> commit_group_indices(new_blocks.size(), -1); |
718 | for (const auto& [stage_id, state] : async_states_local) { |
719 | if (!state.commit_groups.empty()) { |
720 | for (size_t i = 0; i < state.commit_groups.size(); ++i) { |
721 | for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { |
722 | ICHECK(state.commit_groups[i][0] + j < new_blocks.size()); |
723 | commit_group_indices[state.commit_groups[i][0] + j] = stage_id; |
724 | } |
725 | } |
726 | } |
727 | |
728 | if (state.pending_wait.valid()) { |
729 | auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) { |
730 | auto& block = new_blocks[i].block; |
731 | BlockNode* n = block.CopyOnWrite(); |
732 | auto zero = make_zero(DataType::Int(32)); |
733 | n->body = |
734 | AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id, |
735 | AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body)); |
736 | }; |
737 | |
738 | if (state.predicate && !ana_normalized->CanProve(state.predicate.value())) { |
739 | // If the async operation that this wait_queue is waiting on is predicated, and we cannot |
740 | // prove that the predicate is always true, the precise wait count is only valid |
741 | // at iterations where the predicate is true; |
742 | auto wait_count = Call(DataType::Int(32), builtin::if_then_else(), |
743 | {state.predicate.value(), state.pending_wait.wait_count, 0}); |
744 | attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count); |
745 | } else { |
746 | attach_wait_scope(state.pending_wait.insert_before, stage_id, |
747 | state.pending_wait.wait_count); |
748 | } |
749 | } |
750 | } |
751 | |
752 | Array<Stmt> stmts; |
753 | |
754 | for (size_t i = 0; i < new_blocks.size();) { |
755 | if (commit_group_indices[i] == -1) { |
        // A synchronous block, not part of any commit group
757 | stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); |
758 | ++i; |
759 | } else { |
760 | Array<Stmt> group_bodies; |
761 | auto stage_id = commit_group_indices[i]; |
762 | auto predicate = new_blocks[i].predicate; |
763 | for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) { |
764 | ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate)) |
765 | << "Predicates in the same stage are expected to be identical" ; |
766 | group_bodies.push_back(new_blocks[i].block->body); |
767 | } |
768 | |
769 | if (merge_async_commit_queue_scope_ && group_bodies.size() > 1) { |
770 | auto merged_bodies = SeqStmt(group_bodies); |
771 | group_bodies.clear(); |
772 | group_bodies.push_back(merged_bodies); |
773 | } |
774 | |
775 | for (auto body : group_bodies) { |
776 | auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), |
777 | tir::attr::async_commit_queue_scope, stage_id, body); |
778 | auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); |
779 | stmts.push_back(BlockRealize({}, predicate, new_block)); |
780 | } |
781 | } |
782 | } |
783 | |
784 | return stmts; |
785 | } |
786 | |
787 | /*! |
788 | * \brief Emit the pipeline loop in the given range. |
789 | * \param start The start of the range |
790 | * \param end The end of the range |
791 | * \param unroll_loop Whether the loop should be unrolled. |
792 | * \return The result loop. |
793 | */ |
794 | Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) { |
795 | PrimExpr new_loop_var; |
796 | PrimExpr extent = end - start; |
797 | |
798 | auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); }; |
799 | |
800 | if (!analyzer_.CanProve(extent > 0)) { |
801 | return make_nop(); |
802 | } |
803 | bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); |
804 | if (is_unit_loop) { |
805 | new_loop_var = start; // use constants as the loop var for unit loops |
806 | } else { |
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
808 | analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end)); |
809 | } |
810 | |
811 | // In contrast to analyzer_ which is bound to [start, end), this one is bound to |
812 | // the "normalized" range, [pipeline_loop_->min, extent). |
813 | arith::Analyzer ana_normalized; |
814 | if (!is_unit_loop) { |
815 | ana_normalized.Bind(Downcast<Var>(new_loop_var), Range(pipeline_loop_->min, extent)); |
816 | } |
817 | |
818 | std::vector<RewrittenBlockInfo> new_blocks; |
819 | |
820 | // Async related |
821 | std::map<int, AsyncStateLocal> async_states_local; |
822 | std::unordered_map<const BufferNode*, int> buffer_to_commit_group; |
823 | |
824 | for (const Block& block : ordered_stmts_) { |
825 | int stage = pipeline_info_.at(block).stage; |
826 | PrimExpr skewed_loop_var = new_loop_var - stage; |
827 | PrimExpr inbound = analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && |
828 | (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); |
829 | if (analyzer_.CanProve(!inbound)) { |
830 | continue; |
831 | } |
832 | Block new_block = Downcast<Block>(PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, |
833 | pipeline_loop_, max_stage_ != 1, |
834 | fragment_info_)(block)); |
835 | |
836 | PrimExpr delta = start - pipeline_loop_->min; |
837 | // This variable corresponds to |
838 | // - "producer_head" if this stage is an async producer |
839 | // - "consumer_head" if this stage reads from asynchronously written buffers. |
840 | PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; |
841 | |
842 | // Adjust the block predicate and the body according to the final loop bound |
843 | // [pipeline_loop_->min, extent). |
844 | if (!is_unit_loop) { |
845 | Var loop_iter = Downcast<Var>(new_loop_var); |
846 | inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); |
847 | } |
848 | |
849 | new_block = Downcast<Block>( |
850 | Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); |
851 | |
852 | if (pipeline_info_[block].async) { |
853 | auto& local_state = async_states_local[stage]; |
854 | |
855 | int commit_group_id = -1; |
856 | if (local_state.commit_groups.empty() || local_state.consumed || |
857 | !merge_async_commit_queue_scope_) { |
858 | // consumed == true means there is already a consumer stage waiting for an |
          // earlier async operation of this stage. In such cases, we make multiple commit_queue
860 | // for this stage. |
861 | commit_group_id = local_state.commit_groups.size(); |
862 | local_state.commit_groups.push_back({new_blocks.size()}); |
863 | } else { |
864 | // This is the case when one commit_queue groups multiple async blocks. |
865 | // with commit_queue(stage): |
866 | // async_scope: |
867 | // A_shared[...] = ... |
868 | // async_scope: |
869 | // B_shared[...] = ... |
870 | |
871 | commit_group_id = local_state.commit_groups.size() - 1; |
872 | local_state.commit_groups.back().push_back(new_blocks.size()); |
873 | } |
874 | |
875 | for (auto write_region : new_block->writes) { |
876 | async_states[stage].dst_buffers.insert(write_region->buffer.get()); |
877 | buffer_to_commit_group[write_region->buffer.get()] = commit_group_id; |
878 | } |
879 | |
880 | local_state.producer_head = normalized_access_index; |
881 | |
882 | if (!local_state.predicate || ana_normalized.CanProve(local_state.predicate.value())) { |
883 | local_state.predicate = inbound; |
884 | } else if (local_state.predicate) { |
885 | local_state.predicate = ana_normalized.Simplify(local_state.predicate.value() & inbound); |
886 | } |
887 | |
888 | BlockNode* n = new_block.CopyOnWrite(); |
889 | n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body); |
890 | } |
891 | |
892 | new_blocks.push_back( |
893 | {stage, inbound, new_block, normalized_access_index, pipeline_info_[block].async}); |
894 | |
895 | for (auto read_region : new_block->reads) { |
896 | for (auto kv : async_states) { |
897 | int producer_stage_id = kv.first; |
898 | if (producer_stage_id <= stage && kv.second.writes(read_region->buffer)) { |
899 | async_states_local[producer_stage_id].consumed = true; |
900 | } |
901 | } |
902 | } |
903 | } |
904 | |
905 | PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, &async_states_local); |
906 | auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, &ana_normalized); |
907 | |
908 | Stmt new_loop{nullptr}; |
909 | |
910 | if (stmts.empty()) { |
911 | return make_nop(); |
912 | } |
913 | if (stmts.size() == 1) { |
914 | new_loop = stmts[0]; |
915 | } else { |
916 | new_loop = SeqStmt(stmts); |
917 | } |
918 | |
919 | if (!is_unit_loop) { |
920 | new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent, |
921 | unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop), |
922 | NullOpt, preserved_annotations_); |
923 | } |
924 | |
925 | // Update producer heads in the global async states. |
926 | for (const auto& kv : async_states_local) { |
927 | const int stage_id = kv.first; |
928 | const AsyncStateLocal& state = kv.second; |
929 | |
930 | if (state.predicate && ana_normalized.CanProve(state.predicate.value()) && |
931 | async_states[stage_id].producer_head) { |
932 | // Advance the "global" producer head if it is still valid and we know exactly how much we |
933 | // can increment |
934 | async_states[stage_id].producer_head = |
935 | async_states[stage_id].producer_head.value() + extent; |
936 | } else { |
937 | // Otherwise, invalidate the global producer head |
938 | async_states[stage_id].producer_head = NullOpt; |
939 | } |
940 | } |
941 | |
942 | return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_)); |
943 | } |
944 | |
945 | arith::Analyzer analyzer_; |
946 | Map<Var, Buffer> buffer_data_to_buffer_; |
947 | const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers_; |
948 | Array<Buffer> pipeline_allocs_; |
949 | For pipeline_loop_; |
950 | PipelineInfo pipeline_info_; |
951 | const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_; |
952 | int max_stage_ = -1; |
953 | Map<Buffer, Buffer> buffer_remap_; |
954 | Array<Block> ordered_stmts_; |
955 | std::map<int, AsyncStateGlobal> async_states; |
956 | Map<String, ObjectRef> preserved_annotations_; |
957 | bool merge_async_commit_queue_scope_ = true; |
958 | }; |
959 | |
960 | /*! |
 * \brief Build the dependency graph among an array of blocks.
962 | * \param[in] blocks The array of blocks. |
963 | * \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the |
964 | * destination. |
965 | * \param[out] dep_dst2src Optional, a map to store dependency edges from the |
966 | * destination to the source. |
967 | */ |
968 | void BuildDependencyGraph( |
969 | const Array<Block>& blocks, |
970 | std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, |
971 | std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { |
972 | std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual> buffer_writers; |
973 | |
974 | for (const Block& block : blocks) { |
975 | for (const BufferRegion& read : block->reads) { |
976 | auto it = buffer_writers.find(read->buffer->data); |
977 | if (it != buffer_writers.end()) { |
978 | for (const Block& writer : it->second) { |
979 | if (dep_src2dst != nullptr) { |
980 | (*dep_src2dst)[writer].push_back(block); |
981 | } |
982 | if (dep_dst2src != nullptr) { |
983 | (*dep_dst2src)[block].push_back(writer); |
984 | } |
985 | } |
986 | } |
987 | } |
988 | for (const BufferRegion& write : block->writes) { |
989 | buffer_writers[write->buffer->data].push_back(block); |
990 | } |
991 | } |
992 | } |
993 | |
994 | class PipelineInjector : private StmtExprMutator { |
995 | public: |
996 | static Stmt Inject(const PrimFunc& func, bool merge_async_commit_queue_scope) { |
997 | PipelineInjector injector(merge_async_commit_queue_scope); |
998 | for (const auto& kv : func->buffer_map) { |
999 | const Buffer& buffer = kv.second; |
1000 | injector.buffer_data_to_buffer_.Set(buffer->data, buffer); |
1001 | } |
1002 | injector.fragment_info_ = GetTensorCoreFragmentInfo(func->body); |
1003 | return injector(func->body); |
1004 | } |
1005 | |
1006 | private: |
1007 | explicit PipelineInjector(bool merge_async_commit_queue_scope) |
1008 | : merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {} |
1009 | |
1010 | /*! |
   * \brief Check that the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
   * 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for a
   * dependency (e.g. read-after-write) from statement A to statement B, one of the following holds:
1015 | * case 1: stage(A) < stage(B) |
1016 | * case 2: stage(A) == stage(B) and order(A) < order(B) |
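   *
   * For example (schematic): if A writes a buffer that B reads, stage(A) = 0 and stage(B) = 1 is
   * valid regardless of order, whereas stage(A) = stage(B) = 0 additionally requires
   * order(A) < order(B).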
1017 | */ |
1018 | void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) { |
1019 | std::unordered_set<int> used_orders; |
1020 | std::unordered_map<int, int> stage_max_order; |
1021 | std::unordered_map<int, const Block*> order_to_block; |
1022 | std::unordered_map<const Block*, int> block_to_stage; |
1023 | for (const Block& block : original_order) { |
1024 | const auto& stmt_info = pipeline_info.at(block); |
1025 | int order = stmt_info.order; |
1026 | CHECK(!used_orders.count(order)) |
1027 | << "ValueError: Two statements in the software pipeline cannot have the same order" ; |
1028 | used_orders.insert(order); |
1029 | } |
1030 | |
1031 | std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual> dep_src2dst; |
1032 | BuildDependencyGraph(original_order, &dep_src2dst, nullptr); |
1033 | |
1034 | for (const auto& pair : dep_src2dst) { |
1035 | const Block& src = pair.first; |
1036 | const auto& src_info = pipeline_info.at(src); |
1037 | const Array<Block>& dsts = pair.second; |
1038 | for (const Block& dst : dsts) { |
1039 | const auto& dst_info = pipeline_info.at(dst); |
1040 | CHECK_LE(src_info.stage, dst_info.stage) |
1041 | << "ValueError: statement " << dst << " in stage " << dst_info.stage |
1042 | << " cannot depends on statement " << src << " in a later stage " << src_info.stage; |
1043 | if (src_info.stage == dst_info.stage) { |
1044 | CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer " |
1045 | "access dependency in the same stage of the " |
1046 | "software pipeline cannot be reordered" ; |
1047 | } |
1048 | } |
1049 | } |
1050 | } |
1051 | |
1052 | Stmt VisitStmt_(const ForNode* op) final { |
1053 | // Step 1: Recursively rewrite the children first. |
1054 | For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op)); |
1055 | if (!HasPipelineAnnotation(op)) { |
1056 | return std::move(for_node); |
1057 | } |
    // Step 2: Find the body and buffer allocations of the pipeline. The body can be a direct
    // child of the for-loop. If the for-loop has BlockRealize as its child, the pipeline body
    // will be the child of the block.
1061 | Stmt pipeline_body{nullptr}; |
1062 | Array<Buffer> pipeline_allocs; |
1063 | if (const auto* realize = for_node->body.as<BlockRealizeNode>()) { |
1064 | const auto& block = realize->block; |
1065 | for (const auto& buffer : block->alloc_buffers) { |
1066 | ICHECK(buffer->IsInstance<BufferNode>()); |
1067 | buffer_data_to_buffer_.Set(buffer->data, buffer); |
1068 | } |
1069 | pipeline_body = block->body; |
1070 | pipeline_allocs = block->alloc_buffers; |
1071 | } else { |
1072 | pipeline_body = for_node->body; |
1073 | } |
1074 | |
1075 | const SeqStmtNode* pipeline_body_seq = pipeline_body.as<SeqStmtNode>(); |
1076 | CHECK(pipeline_body_seq) |
1077 | << "ValueError: The body of the software pipeline should be SeqStmt, got " |
1078 | << pipeline_body->GetTypeKey(); |
1079 | |
1080 | // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be |
1081 | // converted into a block. |
1082 | PipelineInfo pipeline_info; |
1083 | Array<Block> original_order; // pipeline body blocks in the original order |
1084 | |
1085 | auto f_add_child = [&](const Stmt& child) { |
1086 | original_order.push_back(MakeBlock(child, buffer_data_to_buffer_)); |
1087 | }; |
1088 | for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) { |
1089 | const auto* nested_block_realize = pipeline_body_seq->seq[i].as<BlockRealizeNode>(); |
1090 | if (nested_block_realize && is_one(nested_block_realize->predicate) && |
1091 | nested_block_realize->block->body->IsInstance<SeqStmtNode>()) { |
1092 | const Block& nested_pipeline_block = nested_block_realize->block; |
1093 | ICHECK( |
1094 | nested_pipeline_block->match_buffers.empty()); // match_buffer should have been lowered |
1095 | for (const auto& buffer : nested_pipeline_block->alloc_buffers) { |
1096 | pipeline_allocs.push_back(buffer); |
1097 | buffer_data_to_buffer_.Set(buffer->data, buffer); |
1098 | } |
1099 | const auto* nested_seq = nested_pipeline_block->body.as<SeqStmtNode>(); |
1100 | for (size_t j = 0; j < nested_seq->seq.size(); j++) { |
1101 | f_add_child(nested_seq->seq[j]); |
1102 | } |
1103 | } else { |
1104 | f_add_child(pipeline_body_seq->seq[i]); |
1105 | } |
1106 | } |
1107 | |
1108 | auto pipeline_stages = |
1109 | Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_stage)); |
1110 | auto pipeline_orders = |
1111 | Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_order)); |
1112 | CHECK_EQ(pipeline_stages.size(), original_order.size()); |
1113 | CHECK_EQ(pipeline_orders.size(), original_order.size()); |
1114 | |
1115 | std::unordered_set<int> pipeline_async_stages; |
1116 | if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) { |
1117 | for (auto s : Downcast<Array<Integer>>(annot)) { |
1118 | pipeline_async_stages.insert(s->value); |
1119 | } |
1120 | } |
1121 | |
1122 | Map<String, ObjectRef> preserved_annotations; |
1123 | for (const auto& kv : op->annotations) { |
1124 | const String& key = kv.first; |
1125 | if (kv.first != attr::software_pipeline_stage && kv.first != attr::software_pipeline_order && |
1126 | kv.first != attr::software_pipeline_async_stages) { |
1127 | preserved_annotations.Set(key, kv.second); |
1128 | } |
1129 | } |
1130 | |
1131 | for (size_t i = 0; i < pipeline_stages.size(); i++) { |
1132 | int stage = static_cast<int>(pipeline_stages[i]->value); |
1133 | bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end(); |
1134 | PipelineAnnotation stage_order{stage, |
1135 | /*order=*/static_cast<int>(pipeline_orders[i]->value), |
1136 | is_async}; |
1137 | pipeline_info.emplace(original_order[i], stage_order); |
1138 | } |
1139 | |
1140 | ValidatePipelineBody(pipeline_info, original_order); |
1141 | |
1142 | // Step 4: Rewrite the pipeline body. |
1143 | Stmt pipeline = PipelineRewriter::Rewrite( |
1144 | buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef<For>(op), pipeline_info, |
1145 | fragment_info_, preserved_annotations, merge_async_commit_queue_scope_); |
1146 | |
1147 | if (const auto* realize = op->body.as<BlockRealizeNode>()) { |
1148 | const auto& block = realize->block; |
1149 | for (const auto& buffer : block->alloc_buffers) { |
1150 | buffer_data_to_buffer_.erase(buffer->data); |
1151 | } |
1152 | } |
1153 | return pipeline; |
1154 | } |
1155 | |
1156 | /*! |
1157 | * \brief Add buffer allocations to a block and update the write region of the block. |
1158 | * \param n The block pointer to which the buffer allocations are added. |
1159 | * \param alloc_buffers The buffer allocations to be added. |
1160 | */ |
1161 | void AddAllocBuffers(BlockNode* n, const Array<Buffer> alloc_buffers) { |
1162 | for (const Buffer& alloc_buffer : alloc_buffers) { |
1163 | n->alloc_buffers.push_back(alloc_buffer); |
1164 | Region region; |
1165 | region.reserve(alloc_buffer->shape.size()); |
1166 | for (const PrimExpr& dim : alloc_buffer->shape) { |
1167 | region.push_back(Range::FromMinExtent(0, dim)); |
1168 | } |
1169 | n->writes.push_back(BufferRegion(alloc_buffer, region)); |
1170 | } |
1171 | } |
1172 | |
1173 | Stmt VisitStmt_(const BlockNode* op) final { |
1174 | for (const auto& buffer : op->alloc_buffers) { |
1175 | buffer_data_to_buffer_.Set(buffer->data, buffer); |
1176 | } |
1177 | |
1178 | auto it = op->annotations.find(attr::double_buffer_scope); |
1179 | if (it != op->annotations.end()) { |
1180 | int buffer_index = Downcast<Integer>((*it).second).IntValue(); |
1181 | CHECK(buffer_index >= 0 && static_cast<size_t>(buffer_index) < op->writes.size()) |
1182 | << "ValueError: Index of the buffer exceeds the size of the write regions of the block. (" |
1183 | << buffer_index << " vs. " << op->writes.size() << ")" ; |
1184 | double_buffers.insert(op->writes[buffer_index]->buffer); |
1185 | } |
1186 | Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); |
1187 | |
1188 | for (const auto& buffer : op->alloc_buffers) { |
1189 | buffer_data_to_buffer_.erase(buffer->data); |
1190 | } |
1191 | return std::move(block); |
1192 | } |
1193 | |
1194 | bool HasPipelineAnnotation(const ForNode* op) const { |
1195 | auto it1 = op->annotations.find(attr::software_pipeline_stage); |
1196 | auto it2 = op->annotations.find(attr::software_pipeline_order); |
1197 | bool has_stage = it1 != op->annotations.end(); |
1198 | bool has_order = it2 != op->annotations.end(); |
1199 | if (has_stage && has_order) { |
1200 | return true; |
1201 | } |
1202 | if (has_stage) { |
1203 | LOG(FATAL) << "ValueError: Order of the software pipeline is not defined." ; |
1204 | } |
1205 | if (has_order) { |
1206 | LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined." ; |
1207 | } |
1208 | return false; |
1209 | } |
1210 | |
1211 | Map<Var, Buffer> buffer_data_to_buffer_; |
1212 | std::unordered_map<const VarNode*, FragmentInfo> fragment_info_; |
1213 | std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers; |
1214 | bool merge_async_commit_queue_scope_ = true; |
1215 | }; |
1216 | |
1217 | } // namespace software_pipeline |
1218 | |
1219 | namespace transform { |
1220 | |
1221 | /*! |
 * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers.
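 *
 * The PassContext config "tir.merge_async_commit_queue_scope" (default true) controls whether
 * adjacent async producers in the same stage share a single async_commit_queue scope.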
1223 | * \return The IR transform pass. |
1224 | */ |
1225 | Pass InjectSoftwarePipeline() { |
1226 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
1227 | auto* fptr = f.CopyOnWrite(); |
1228 | bool merge_async_commit_queue_scope = |
1229 | ctx->GetConfig<Bool>("tir.merge_async_commit_queue_scope" , Bool(true)).value(); |
1230 | fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope); |
1231 | fptr->body = ConvertSSA(std::move(fptr->body)); |
1232 | return f; |
1233 | }; |
  return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {});
1235 | } |
1236 | |
1237 | TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline" ).set_body_typed(InjectSoftwarePipeline); |
1238 | |
1239 | } // namespace transform |
1240 | |
1241 | } // namespace tir |
1242 | } // namespace tvm |
1243 | |