/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_software_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers
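 *
 * A loop annotated with `software_pipeline_stage` and `software_pipeline_order` is rewritten
 * into a prologue, a steady-state body and an epilogue. Buffers whose lifetime crosses stage
 * boundaries are multi-versioned so that different stages of consecutive iterations can overlap,
 * and stages listed in `software_pipeline_async_stages` are wrapped in async commit/wait scopes.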
 */
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "../../support/utils.h"
#include "../schedule/utils.h"
#include "./ir_utils.h"

namespace tvm {
namespace tir {

namespace software_pipeline {

/*!
 * \brief Create a block and infer the access region with the given body.
 *
 * The result is an opaque block that doesn't contain any block iter vars. If the body is a
 * block realize without a predicate, there is no need to create a new block; the block of the
 * block realize is returned instead.
 *
 * \param body The body of the block.
 * \param buffer_data_to_buffer The map from buffer data to buffer.
 * \return The result block.
 */
Block MakeBlock(const Stmt& body, const Map<Var, Buffer>& buffer_data_to_buffer) {
  if (const BlockRealizeNode* block_realize = body.as<BlockRealizeNode>()) {
    if (is_one(block_realize->predicate)) {
      // no need to create a new block
      return block_realize->block;
    }
  }
  Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"", /*body*/ body);
  Array<Array<BufferRegion>> access = GetBlockReadWriteRegion(block, buffer_data_to_buffer);
  BlockNode* n = block.CopyOnWrite();
  n->reads = access[0];
  n->writes = access[1];
  return block;
}

/*! Structure that represents the provided annotation per block or loop. */
struct PipelineAnnotation {
  int stage;
  int order;
  bool async;
};

using PipelineInfo = std::unordered_map<Block, PipelineAnnotation, ObjectPtrHash, ObjectPtrEqual>;

struct BufferAccessInfo {
  int def = -1;  // the defining stage of the buffer
  int use = -1;  // the last using stage of the buffer
};
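// For example, a buffer written by a block in stage 0 and last read by a block in stage 1 has
// def = 0 and use = 1, so up to two versions of it may need to be kept alive at the same time.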

class PipelineOpaqueAccessRewriter {
 public:
  /*!
   * \brief Constructor
   * \param buffer_data_to_buffer The map from buffer data to buffer.
   * \param buffer_remap The map from original buffer to the buffer with updated shape for
   * multi-versioning in the software pipeline.
   * \param pipeline_loop The original loop to be software pipelined.
   * \param fragment_info Information about tensor core fragment
   */
  PipelineOpaqueAccessRewriter(
      const Map<Var, Buffer>& buffer_data_to_buffer, const Map<Buffer, Buffer>& buffer_remap,
      const For& pipeline_loop,
      const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info)
      : buffer_data_to_buffer_(buffer_data_to_buffer),
        buffer_remap_(buffer_remap),
        pipeline_loop_(pipeline_loop),
        fragment_info_(fragment_info) {}

  PrimExpr Rewrite(const Call& call) {
    // Intrinsic calls should be handled explicitly here as they are opaque accesses to
    // buffers.
    static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync();
    static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync();
    static const auto& mma_sync = builtin::tvm_mma_sync();
    static const auto& access_ptr = builtin::tvm_access_ptr();
    static const auto& ptx_ldmatrix = builtin::ptx_ldmatrix();
    static const auto& ptx_mma = builtin::ptx_mma();
    if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) {
      const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[0]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        Array<PrimExpr> new_args = call->args;
        const Buffer& new_buffer = (*it).second;
        new_args.Set(4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4]));
        return Call(call->dtype, call->op, new_args, call->span);
      }
    } else if (call->op.same_as(mma_sync)) {
      Array<PrimExpr> new_args = call->args;
      for (int i = 0; i < 4; i++) {
        const Var& buffer_var = Downcast<Var>(call->args[i * 2]);
        const PrimExpr& index = call->args[i * 2 + 1];
        const Buffer& buffer = buffer_data_to_buffer_.at(buffer_var);
        auto it = buffer_remap_.find(buffer);
        if (it != buffer_remap_.end()) {
          PrimExpr new_index = RewriteWmmaFragmentIndex(buffer, (*it).second, index);
          new_args.Set(i * 2 + 1, new_index);
        }
      }
      return Call(call->dtype, call->op, new_args, call->span);
    } else if (call->op.same_as(access_ptr)) {
      return RewriteBufferAccess(call, {1});
    } else if (call->op.same_as(ptx_mma)) {
      return RewriteBufferAccess(call, {6, 8, 10});
    } else if (call->op.same_as(ptx_ldmatrix)) {
      return RewriteBufferAccess(call, {3});
    }
    return call;
  }

 private:
  int GetWmmaFragmentSize(const Buffer& buffer) {
    auto it = fragment_info_.find(buffer->data.get());
    ICHECK(it != fragment_info_.end());
    const FragmentInfo& info = (*it).second;
    return info.GetSize();
  }

  PrimExpr RewriteWmmaFragmentIndex(const Buffer& old_buffer, const Buffer& new_buffer,
                                    const PrimExpr& old_index) {
    PrimExpr new_buffer_offset = old_index;

    int fragment_size = GetWmmaFragmentSize(old_buffer);
    PrimExpr offset =
        floordiv(foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                       make_const(DataType::Int(32), 1), old_buffer->shape),
                 fragment_size);
    new_buffer_offset +=
        floormod(pipeline_loop_->loop_var - pipeline_loop_->min, new_buffer->shape[0]) * offset;
    return new_buffer_offset;
  }

  PrimExpr RewriteBufferAccess(const Call& call, const std::vector<int> arg_indices) {
    auto product = [](const Array<PrimExpr>& input) {
      return foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                   make_const(DataType::Int(32), 1), input);
    };
    Array<PrimExpr> new_args = call->args;
    for (int i : arg_indices) {
      const Buffer& buffer = buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
      auto it = buffer_remap_.find(buffer);
      if (it != buffer_remap_.end()) {
        const Buffer& new_buffer = (*it).second;
        const PrimExpr& old_index = call->args[i + 1];
        PrimExpr offset;
        if (new_buffer->strides.empty()) {
          offset = product(buffer->shape);
        } else {
          offset = new_buffer->strides[0];
        }
        PrimExpr new_index =
            old_index + floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
        new_args.Set(i + 1, new_index);
      }
    }
    return Call(call->dtype, call->op, new_args, call->span);
  }

  const Map<Var, Buffer>& buffer_data_to_buffer_;
  const Map<Buffer, Buffer>& buffer_remap_;
  const For& pipeline_loop_;
  const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_;
};

/*!
 * \brief Rewriter for the body of the software pipeline. This pass inserts `floormod` into the
 * indices of remapped buffers to select the version corresponding to the pipeline stage.
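 *
 * For example, if buffer A is double-buffered, an access A[i] indexed by the pipeline loop
 * variable i is rewritten to A[(i - min) % 2, i], where min is the minimum of the pipeline loop.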
 */
class PipelineBodyRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Constructor of PipelineBodyRewriter.
   * \param buffer_data_to_buffer The map from buffer data to buffer.
   * \param buffer_remap The map from original buffer to the buffer with updated shape for
   * multi-versioning in the software pipeline.
   * \param pipeline_loop The original loop to be software pipelined.
   * \param access_all_versions Whether all versions of the buffers in the software pipeline are
   * accessed. This will be used to update block access regions. In the prologue and epilogue
   * of a two-stage software pipeline, only one version of these buffers is accessed.
   * \param fragment_info Information about tensor core fragment
   */
  PipelineBodyRewriter(const Map<Var, Buffer>& buffer_data_to_buffer,
                       const Map<Buffer, Buffer>& buffer_remap, For pipeline_loop,
                       bool access_all_versions,
                       const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info)
      : buffer_data_to_buffer_(buffer_data_to_buffer),
        buffer_remap_(buffer_remap),
        pipeline_loop_(pipeline_loop),
        access_all_versions_(access_all_versions),
        opaque_access_rewriter_(buffer_data_to_buffer_, buffer_remap_, pipeline_loop_,
                                fragment_info) {}

 private:
  BufferRegion RewritePipelineBufferRegion(const BufferRegion& buffer_region) const {
    auto it = buffer_remap_.find(buffer_region->buffer);
    if (it != buffer_remap_.end()) {
      Region new_region = buffer_region->region;
      const Buffer& new_buffer = (*it).second;
      // For pipeline buffers, relax the access region of the first dimension to full extent
      // if access_all_versions == true
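      // (e.g. A[0:16] becomes A[0:2, 0:16] when A is double-buffered); otherwise select only the
      // version accessed at the current iteration, i.e. A[(i - min) % 2, 0:16].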
      Range accessed_version =
          access_all_versions_
              ? Range::FromMinExtent(0, new_buffer->shape[0])
              : Range::FromMinExtent(floormod((pipeline_loop_->loop_var - pipeline_loop_->min),
                                              new_buffer->shape[0]),
                                     Integer(1));
      new_region.insert(new_region.begin(), accessed_version);
      return BufferRegion(new_buffer, new_region);
    }
    return buffer_region;
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    for (const Buffer& alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
    BlockNode* n = block.CopyOnWrite();
    n->reads.MutateByApply([this](const BufferRegion& buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    n->writes.MutateByApply([this](const BufferRegion& buffer_region) {
      return RewritePipelineBufferRegion(buffer_region);
    });
    for (const Buffer& alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
    return std::move(block);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(store);
    }
    const Buffer& new_buffer = (*it).second;
    auto* n = store.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version =
        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return std::move(store);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
      return std::move(load);
    }
    const Buffer& new_buffer = (*it).second;
    auto* n = load.CopyOnWrite();
    n->buffer = new_buffer;
    PrimExpr version =
        floormod((pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
    return std::move(load);
  }

  PrimExpr VisitExpr_(const CallNode* op) final {
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    return opaque_access_rewriter_.Rewrite(call);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
  PipelineOpaqueAccessRewriter opaque_access_rewriter_;
};

/*!
 * \brief Rewriter for the software pipeline that rewrites a loop into a pipelined one.
 */
class PipelineRewriter : public StmtExprMutator {
 public:
  static Stmt Rewrite(
      Map<Var, Buffer> buffer_data_to_buffer,
      const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers,
      const Array<Buffer> pipeline_allocs, const For& pipeline_loop,
      const PipelineInfo& pipeline_info,
      const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
      const Map<String, ObjectRef> preserved_annotations, bool merge_async_commit_queue_scope) {
    PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop,
                              pipeline_info, fragment_info, preserved_annotations,
                              merge_async_commit_queue_scope);
    return rewriter.BuildPipeline();
  }

 private:
  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
                   const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers,
                   const Array<Buffer>& pipeline_allocs, const For& pipeline_loop,
                   const PipelineInfo& pipeline_info,
                   const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
                   const Map<String, ObjectRef> preserved_annotations,
                   bool merge_async_commit_queue_scope)

      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
        double_buffers_(double_buffers),
        pipeline_allocs_(pipeline_allocs),
        pipeline_loop_(pipeline_loop),
        pipeline_info_(pipeline_info),
        fragment_info_(fragment_info),
        preserved_annotations_(preserved_annotations),
        merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {}

  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
    // that need to be maintained for each buffer.
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos =
        GetBufferAccessInfo();
    for (const Buffer& buffer : pipeline_allocs_) {
      int num_versions = ComputeBufferVersions(buffer, infos.at(buffer));
      if (num_versions > 1) {
        buffer_remap_.Set(buffer, RewriteAllocBuffer(buffer, num_versions));
      }
    }

    ordered_stmts_.resize(pipeline_info_.size());
    for (const auto& pair : pipeline_info_) {
      const Block& block = pair.first;
      int order = pair.second.order;
      ordered_stmts_.Set(order, block);
    }

    // Step 2: Emit the pipeline prologue, body and epilogue.
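    // The prologue covers [min, min + max_stage), the steady-state body covers
    // [min + max_stage, min + extent), and the epilogue covers
    // [min + extent, min + extent + max_stage). Blocks whose skewed iteration falls outside the
    // original loop range are predicated or dropped, and the prologue/epilogue are unrolled.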
    Stmt prologue = EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true);
    Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
                         pipeline_loop_->min + pipeline_loop_->extent, false);
    Stmt epilogue = EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
                             pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true);

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});

    // Step 3: Make a new block that contains new buffer allocations after pipeline rewriting.
    Array<Buffer> alloc_buffers;
    for (const auto& alloc : pipeline_allocs_) {
      alloc_buffers.push_back(buffer_remap_.Get(alloc).value_or(alloc));
      buffer_data_to_buffer_.erase(alloc->data);
    }
    Block block = MakeBlock(stmt, buffer_data_to_buffer_);
    block.CopyOnWrite()->alloc_buffers = std::move(alloc_buffers);
    return BlockRealize({}, Bool(true), block);
  }

 private:
  /*!
   * \brief Analyze accesses to the buffers in the software pipeline.
   *
   * This method checks the 'define' and 'use' stage of the buffers in the software pipeline, which
   * can be used to compute the number of versions that need to be maintained after rewriting.
   */
  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>
  GetBufferAccessInfo() {
    std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> infos;
    for (const auto& pair : pipeline_info_) {
      const Block& block = pair.first;
      int stage = pair.second.stage;
      max_stage_ = std::max(max_stage_, stage);

      for (const BufferRegion& write : block->writes) {
        if (!infos.count(write->buffer)) {
          infos.emplace(write->buffer, BufferAccessInfo{});
        }
        auto& info = infos.at(write->buffer);
        if (info.def == -1) {
          info.def = stage;
        } else {
          info.def = std::min(info.def, stage);
        }
      }

      for (const BufferRegion& read : block->reads) {
        if (!infos.count(read->buffer)) {
          infos.emplace(read->buffer, BufferAccessInfo{});
        }
        auto& info = infos.at(read->buffer);
        info.use = std::max(info.use, stage);
      }
    }
    return infos;
  }

  /*!
   * \brief Check whether two regions have intersections.
   * \param region1 The first region.
   * \param region2 The second region.
   * \return Whether region1 and region2 have intersections.
   */
  bool MayConflict(Region region1, Region region2) {
    ICHECK(region1.size() == region2.size());
    for (size_t i = 0; i < region1.size(); i++) {
      Range dim1 = region1[i];
      Range dim2 = region2[i];
      auto int_set1 = arith::IntSet::FromRange(dim1);
      auto int_set2 = arith::IntSet::FromRange(dim2);
      if (arith::Intersect({int_set1, int_set2}).IsNothing()) {
        return false;
      }
    }
    return true;
  }

  /*!
   * \brief Compute the number of versions that need to be maintained for a buffer accessed in the
   * software pipeline.
   *
   * This method applies liveness analysis to the target buffer to compute the number of versions
   * that need to be maintained during the software pipeline.
   * Annotation `attr::double_buffer_scope` is handled here, which provides a way to override the
   * result of the analysis. Additional double buffering in the software pipeline can be useful
   * to eliminate synchronizations in GPU devices.
   *
   * \param buffer The target buffer
   * \param buffer_info The access information of the target buffer.
   * \return The number of versions required for the target buffer.
   */
  int ComputeBufferVersions(const Buffer& buffer, const BufferAccessInfo& buffer_info) {
    if (buffer_info.def == -1) {
      // Keep the original number of versions as buffers defined outside the software pipeline
      // should not be mutated.
      return 1;
    }

    // `use - def + 1` is an upper bound of the needed versions.
    // We optimize a few cases where the number of versions can be smaller than the upper bound.
    int num_versions = buffer_info.use - buffer_info.def + 1;
    if (num_versions == 2) {
      // A special case when `use - def + 1 == 2`. Double buffering is only needed in this case
      // when there exists a writer block_i and a reader block_j such that
      // order(block_i) < order(block_j) and stage(block_i) < stage(block_j) and the access regions
      // of block_i and block_j overlap.
      bool need_multi_version = false;
      for (const auto& pair1 : pipeline_info_) {
        const Block& writer_block = pair1.first;
        const auto& writer_info = pair1.second;

        auto it1 = std::find_if(writer_block->writes.begin(), writer_block->writes.end(),
                                [&](const BufferRegion& buffer_region) {
                                  return buffer_region->buffer.same_as(buffer);
                                });
        if (it1 == writer_block->writes.end()) {
          continue;
        }

        for (const auto& pair2 : pipeline_info_) {
          const Block& reader_block = pair2.first;
          const auto& reader_info = pair2.second;
          auto it2 = std::find_if(reader_block->reads.begin(), reader_block->reads.end(),
                                  [&](const BufferRegion& buffer_region) {
                                    return buffer_region->buffer.same_as(buffer);
                                  });
          if (it2 == reader_block->reads.end()) {
            continue;
          }
          if (writer_info.order < reader_info.order && writer_info.stage < reader_info.stage &&
              MayConflict((*it1)->region, (*it2)->region)) {
            need_multi_version = true;
            break;
          }
        }
      }
      if (!need_multi_version) {
        num_versions = 1;
      }
    }
    if (num_versions == 1 && double_buffers_.count(buffer)) {
      num_versions = 2;
    }
    return num_versions;
  }

  /*!
   * \brief Rewrite buffer allocation to keep multiple versions of the original buffer for
   * pipelined accesses.
   * \param buffer The buffer to be resized.
   * \param num_versions The number of versions to keep.
   * \return The resized buffer.
   */
  Buffer RewriteAllocBuffer(const Buffer& buffer, int num_versions) {
    ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
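    // Prepend the version dimension; for example, a [128, 128] buffer kept in two versions
    // becomes [2, 128, 128].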
    new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
    if (new_buffer->strides.size()) {
      ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
      PrimExpr stride_0 = new_buffer->strides[0] * new_buffer->shape[1];
      new_buffer->strides.insert(new_buffer->strides.begin(), stride_0);
    }
    return Buffer(new_buffer);
  }

  // Per-stage states that need to be tracked across pipeline prologue, body, and epilogue.
  struct AsyncStateGlobal {
    // Buffers that this stage asynchronously writes.
    std::unordered_set<const BufferNode*> dst_buffers;
    // An imaginary index that the latest async operation associated with this stage has written
    // into. Only valid if all associated predicates are true, so that we can count the number of
    // async invocations exactly. When it is valid, it is the "sum of extents of loops that have
    // been executed" - 1, e.g. for epilogue it is prologue extent + body extent - 1. This
    // is only needed to compute wait count for epilogue without async producers.
    Optional<PrimExpr> producer_head{PrimExpr(-1)};

    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
  };

  // Per-stage states that are local to each of pipeline prologue, body, and epilogue.
  struct AsyncStateLocal {
    struct {
      // The index into a list of blocks, where async_wait_queue should be attached at the
      // beginning.
      int insert_before;
      // in_flight_count would be a more precise name, but the implementation uses wait_count for
      // brevity.
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
    } pending_wait;

    // Destination buffers of async operations that have been encountered so far in the loop
    //
    // for (size_t i = 0; i < new_blocks.size(); ++i) {
    //   ...
    // }
    //
    // This is for tracking which async operations have been issued at the "current" iteration, up
    // until a point where we encounter a consumer of async result buffers. This is used to decide
    // if the producer_head of each buffer points to a copy written in the current or previous
    // iteration.
    std::unordered_set<const BufferNode*> seen;

    // A symbolic expression representing the index the latest async operation associated with this
    // stage has written into, at the "current" iteration.
    Optional<PrimExpr> producer_head;
    // The predicate of BlockRealize containing the async operation of this stage.
    Optional<PrimExpr> predicate;
    // Indices into a list of blocks, where async_commit_queue scope should be attached.
    // If multiple async producers are interleaved with their consumer in between, we need separate
    // async_commit_queue for each producer. Thus, we need multiple sets of indices.
    std::vector<std::vector<size_t>> commit_groups;

    // This is set to true when we reach a stage that consumes this async stage.
    bool consumed{false};
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };

  // Determine where to insert async_wait and the corresponding wait count.
  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
                          arith::Analyzer* ana_normalized,
                          const std::unordered_map<const BufferNode*, int>& buffer_to_commit_group,
                          std::map<int, AsyncStateLocal>* async_states_local) {
    for (size_t i = 0; i < new_blocks.size(); ++i) {
      if (new_blocks[i].is_async) {
        // Record the fact that we have encountered these write buffers.
        for (auto write_region : new_blocks[i].block->writes) {
          (*async_states_local)[new_blocks[i].stage].seen.insert(write_region->buffer.get());
        }
      }

      int producer_stage_idx = -1;
      for (auto read_region : new_blocks[i].block->reads) {
        for (auto kv : async_states) {
          if (kv.first <= new_blocks[i].stage && kv.second.writes(read_region->buffer)) {
            // Found an earlier stage where read_region->buffer was asynchronously written
            ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first)
                << "A dependency on multiple async stages is not supported";
            producer_stage_idx = kv.first;
          }
        }
      }

      if (producer_stage_idx == -1) continue;

      // The following logic has become complicated to handle cases like this:
      //
      // for i in range(13):
      //     # Stage 0
      //     async_commit_queue(0):
      //        async_scope:
      //           A_shared[(i + 3) % 4] = A[...]
      //
      //
      //     # Stage 1
      //     async_wait_queue(0, 5):
      //        compute(A_shared[i], B_shared[i])
      //
      //     # Stage 0
      //     async_commit_queue(0)
      //        async_scope:
      //           B_shared[(i + 3) % 4] = B[...]
      //
      //
      // Here, multiple async producers in the same stage are interleaved with their consumer in
      // between. Since each buffer is associated with different commit groups, the wait_count
      // before the consumer should be bigger than in the simpler case:
      //
      // for i in range(13):
      //     # Stage 0
      //     async_commit_queue(0):
      //        async_scope:
      //           A_shared[(i + 3) % 4] = A[...]
      //           B_shared[(i + 3) % 4] = B[...]
      //
      //     # Stage 1
      //     async_wait_queue(0, 3):
      //        compute(A_shared[i], B_shared[i])
      //
      // The correct wait_count can be determined by considering each commit group separately, and
      // summing "per-commit" wait_counts.
      //
      // From A_shared's perspective, it allows for (i + 3) - i async commit groups to be in
      // flight while from B_shared's perspective, the producer head at compute points to the copy
      // done by the previous iteration, so its wait_count is calculated as ((i - 1) + 3) - i. The
      // sum of the two wait_counts gives 5.

      auto& dep_local_state = (*async_states_local)[producer_stage_idx];
      const auto num_commit_group = dep_local_state.commit_groups.size();
      std::vector<Optional<PrimExpr>> producer_head_per_commit;

      if (num_commit_group == 0) {
        // Epilogue, no async producer. Since "local" producer_head is not available, use
        // "global" producer_head.
        ICHECK(!dep_local_state.producer_head);
        producer_head_per_commit.push_back(async_states[producer_stage_idx].producer_head);
      } else {
        ICHECK(dep_local_state.producer_head);
        std::vector<bool> need_wait_count(num_commit_group, true);

        for (auto read_region : new_blocks[i].block->reads) {
          if (!async_states[producer_stage_idx].writes(read_region->buffer)) continue;
          auto commit_group_id = buffer_to_commit_group.at(read_region->buffer.get());
          if (!need_wait_count[commit_group_id]) continue;

          if (!dep_local_state.seen.count(read_region->buffer.get())) {
            // Multiple async producers interleaved: The most recent async write is from the
            // previous iteration. This is the B_shared case above.
            producer_head_per_commit.push_back(dep_local_state.producer_head.value() - 1);
          } else {
            // Normal case
            producer_head_per_commit.push_back(dep_local_state.producer_head.value());
          }

          need_wait_count[commit_group_id] = false;
        }
      }

      auto wait_count = [=, &ana_normalized]() {
        auto sum = PrimExpr(0);
        for (auto producer_head : producer_head_per_commit) {
          if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) {
            // Here, new_blocks[i].access_index corresponds to "consumer_head".
            // The difference of producer_head and consumer_head is precisely the number of
            // async commit groups that can still be in flight after this wait.
            sum += analyzer_.Simplify(producer_head.value() - new_blocks[i].access_index);
          } else {
            // The precise count cannot be determined, give up.
            return PrimExpr(0);
          }
        }
        return sum;
      }();

      auto& pending_wait = dep_local_state.pending_wait;

      if (!pending_wait.valid()) {
        pending_wait = {static_cast<int>(i), wait_count};
      } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) {
        // Coalesce multiple wait_queue if the later one allows fewer in-flight ops.
        pending_wait = {pending_wait.insert_before, wait_count};
      }
    }
  }

  // Given pipelined blocks and async-related information, generate final loop statements with
  // async scopes (if any).
  Array<Stmt> CompletePipelineLoopStatements(
      const std::vector<RewrittenBlockInfo>& blocks,
      const std::map<int, AsyncStateLocal>& async_states_local,
      arith::Analyzer* ana_normalized) const {
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
    std::vector<int> commit_group_indices(new_blocks.size(), -1);
    for (const auto& [stage_id, state] : async_states_local) {
      if (!state.commit_groups.empty()) {
        for (size_t i = 0; i < state.commit_groups.size(); ++i) {
          for (size_t j = 0; j < state.commit_groups[i].size(); ++j) {
            ICHECK(state.commit_groups[i][0] + j < new_blocks.size());
            commit_group_indices[state.commit_groups[i][0] + j] = stage_id;
          }
        }
      }

      if (state.pending_wait.valid()) {
        auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) {
          auto& block = new_blocks[i].block;
          BlockNode* n = block.CopyOnWrite();
          auto zero = make_zero(DataType::Int(32));
          n->body =
              AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                       AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body));
        };
        if (state.predicate && !ana_normalized->CanProve(state.predicate.value())) {
          // If the async operation that this wait_queue is waiting on is predicated, and we cannot
          // prove that the predicate is always true, the precise wait count is only valid
          // at iterations where the predicate is true.
          auto wait_count = Call(DataType::Int(32), builtin::if_then_else(),
                                 {state.predicate.value(), state.pending_wait.wait_count, 0});
          attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count);
        } else {
          attach_wait_scope(state.pending_wait.insert_before, stage_id,
                            state.pending_wait.wait_count);
        }
      }
    }

    Array<Stmt> stmts;

    for (size_t i = 0; i < new_blocks.size();) {
      if (commit_group_indices[i] == -1) {
        // A synchronous block, not part of any commit group
        stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block));
        ++i;
      } else {
        Array<Stmt> group_bodies;
        auto stage_id = commit_group_indices[i];
        auto predicate = new_blocks[i].predicate;
        for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) {
          ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate))
              << "Predicates in the same stage are expected to be identical";
          group_bodies.push_back(new_blocks[i].block->body);
        }

        if (merge_async_commit_queue_scope_ && group_bodies.size() > 1) {
          auto merged_bodies = SeqStmt(group_bodies);
          group_bodies.clear();
          group_bodies.push_back(merged_bodies);
        }

        for (auto body : group_bodies) {
          auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                             tir::attr::async_commit_queue_scope, stage_id, body);
          auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
          stmts.push_back(BlockRealize({}, predicate, new_block));
        }
      }
    }

    return stmts;
  }

  /*!
   * \brief Emit the pipeline loop in the given range.
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
   * \return The result loop.
   */
  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) {
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;

    auto make_nop = []() { return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); };

    if (!analyzer_.CanProve(extent > 0)) {
      return make_nop();
    }
    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
      new_loop_var = start;  // use constants as the loop var for unit loops
    } else {
      new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

    // In contrast to analyzer_ which is bound to [start, end), this one is bound to
    // the "normalized" range, [pipeline_loop_->min, extent).
    arith::Analyzer ana_normalized;
    if (!is_unit_loop) {
      ana_normalized.Bind(Downcast<Var>(new_loop_var), Range(pipeline_loop_->min, extent));
    }

    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;
    std::unordered_map<const BufferNode*, int> buffer_to_commit_group;

    for (const Block& block : ordered_stmts_) {
      int stage = pipeline_info_.at(block).stage;
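      // A statement in stage `s` is skewed by `s` iterations: at iteration `i` of the emitted
      // loop it performs the work of original iteration `i - s`, which is how different stages
      // of consecutive original iterations come to overlap.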
      PrimExpr skewed_loop_var = new_loop_var - stage;
      PrimExpr inbound = analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
                         (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
      Block new_block = Downcast<Block>(PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
                                                             pipeline_loop_, max_stage_ != 1,
                                                             fragment_info_)(block));

      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
      // - "consumer_head" if this stage reads from asynchronously written buffers.
      PrimExpr normalized_access_index = is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;

      // Adjust the block predicate and the body according to the final loop bound
      // [pipeline_loop_->min, extent).
      if (!is_unit_loop) {
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }

      new_block = Downcast<Block>(
          Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));

      if (pipeline_info_[block].async) {
        auto& local_state = async_states_local[stage];

        int commit_group_id = -1;
        if (local_state.commit_groups.empty() || local_state.consumed ||
            !merge_async_commit_queue_scope_) {
          // consumed == true means there is already a consumer stage waiting for an
          // earlier async operation of this stage. In such cases, we make multiple commit_queue
          // for this stage.
          commit_group_id = local_state.commit_groups.size();
          local_state.commit_groups.push_back({new_blocks.size()});
        } else {
          // This is the case when one commit_queue groups multiple async blocks.
          // with commit_queue(stage):
          //   async_scope:
          //     A_shared[...] = ...
          //   async_scope:
          //     B_shared[...] = ...

          commit_group_id = local_state.commit_groups.size() - 1;
          local_state.commit_groups.back().push_back(new_blocks.size());
        }

        for (auto write_region : new_block->writes) {
          async_states[stage].dst_buffers.insert(write_region->buffer.get());
          buffer_to_commit_group[write_region->buffer.get()] = commit_group_id;
        }

        local_state.producer_head = normalized_access_index;

        if (!local_state.predicate || ana_normalized.CanProve(local_state.predicate.value())) {
          local_state.predicate = inbound;
        } else if (local_state.predicate) {
          local_state.predicate = ana_normalized.Simplify(local_state.predicate.value() & inbound);
        }

        BlockNode* n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, 1, n->body);
      }

      new_blocks.push_back(
          {stage, inbound, new_block, normalized_access_index, pipeline_info_[block].async});

      for (auto read_region : new_block->reads) {
        for (auto kv : async_states) {
          int producer_stage_id = kv.first;
          if (producer_stage_id <= stage && kv.second.writes(read_region->buffer)) {
            async_states_local[producer_stage_id].consumed = true;
          }
        }
      }
    }

    PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, &async_states_local);
    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, &ana_normalized);

    Stmt new_loop{nullptr};

    if (stmts.empty()) {
      return make_nop();
    }
    if (stmts.size() == 1) {
      new_loop = stmts[0];
    } else {
      new_loop = SeqStmt(stmts);
    }

    if (!is_unit_loop) {
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop),
                     NullOpt, preserved_annotations_);
    }

    // Update producer heads in the global async states.
    for (const auto& kv : async_states_local) {
      const int stage_id = kv.first;
      const AsyncStateLocal& state = kv.second;

      if (state.predicate && ana_normalized.CanProve(state.predicate.value()) &&
          async_states[stage_id].producer_head) {
        // Advance the "global" producer head if it is still valid and we know exactly how much we
        // can increment
        async_states[stage_id].producer_head =
            async_states[stage_id].producer_head.value() + extent;
      } else {
        // Otherwise, invalidate the global producer head
        async_states[stage_id].producer_head = NullOpt;
      }
    }

    return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
  }

  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
  const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
  Map<String, ObjectRef> preserved_annotations_;
  bool merge_async_commit_queue_scope_ = true;
};

/*!
 * \brief Build the dependency graph among an array of blocks.
 * \param[in] blocks The array of blocks.
 * \param[out] dep_src2dst Optional, a map to store dependency edges from the source to the
 *             destination.
 * \param[out] dep_dst2src Optional, a map to store dependency edges from the destination to the
 *             source.
 */
void BuildDependencyGraph(
    const Array<Block>& blocks,
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst,
    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) {
  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;

  for (const Block& block : blocks) {
    for (const BufferRegion& read : block->reads) {
      auto it = buffer_writers.find(read->buffer->data);
      if (it != buffer_writers.end()) {
        for (const Block& writer : it->second) {
          if (dep_src2dst != nullptr) {
            (*dep_src2dst)[writer].push_back(block);
          }
          if (dep_dst2src != nullptr) {
            (*dep_dst2src)[block].push_back(writer);
          }
        }
      }
    }
    for (const BufferRegion& write : block->writes) {
      buffer_writers[write->buffer->data].push_back(block);
    }
  }
}

class PipelineInjector : private StmtExprMutator {
 public:
  static Stmt Inject(const PrimFunc& func, bool merge_async_commit_queue_scope) {
    PipelineInjector injector(merge_async_commit_queue_scope);
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    injector.fragment_info_ = GetTensorCoreFragmentInfo(func->body);
    return injector(func->body);
  }

 private:
  explicit PipelineInjector(bool merge_async_commit_queue_scope)
      : merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {}

  /*!
   * \brief Check that the pipeline satisfies the following conditions:
   * 1. No conflicting order: The order of each statement should be unique.
   * 2. Reordering of statements doesn't break buffer access dependencies. Specifically, for a
   *    dependency (e.g. read-after-write) from statement A to statement B, it requires:
   *      case 1: stage(A) < stage(B)
   *      case 2: stage(A) == stage(B) and order(A) < order(B)
   */
  void ValidatePipelineBody(const PipelineInfo& pipeline_info, const Array<Block>& original_order) {
    std::unordered_set<int> used_orders;
    std::unordered_map<int, int> stage_max_order;
    std::unordered_map<int, const Block*> order_to_block;
    std::unordered_map<const Block*, int> block_to_stage;
    for (const Block& block : original_order) {
      const auto& stmt_info = pipeline_info.at(block);
      int order = stmt_info.order;
      CHECK(!used_orders.count(order))
          << "ValueError: Two statements in the software pipeline cannot have the same order";
      used_orders.insert(order);
    }

    std::unordered_map<Block, Array<Block>, ObjectPtrHash, ObjectPtrEqual> dep_src2dst;
    BuildDependencyGraph(original_order, &dep_src2dst, nullptr);

    for (const auto& pair : dep_src2dst) {
      const Block& src = pair.first;
      const auto& src_info = pipeline_info.at(src);
      const Array<Block>& dsts = pair.second;
      for (const Block& dst : dsts) {
        const auto& dst_info = pipeline_info.at(dst);
        CHECK_LE(src_info.stage, dst_info.stage)
            << "ValueError: statement " << dst << " in stage " << dst_info.stage
            << " cannot depend on statement " << src << " in a later stage " << src_info.stage;
        if (src_info.stage == dst_info.stage) {
          CHECK_LT(src_info.order, dst_info.order) << "ValueError: two statements with buffer "
                                                      "access dependency in the same stage of the "
                                                      "software pipeline cannot be reordered";
        }
      }
    }
  }

  Stmt VisitStmt_(const ForNode* op) final {
    // Step 1: Recursively rewrite the children first.
    For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
    if (!HasPipelineAnnotation(op)) {
      return std::move(for_node);
    }
    // Step 2: Find the body and buffer allocations of the pipeline. The body can be a direct
    // child of the for-loop. If the for-loop has a BlockRealize as its child, the pipeline body
    // is the child of that block.
    Stmt pipeline_body{nullptr};
    Array<Buffer> pipeline_allocs;
    if (const auto* realize = for_node->body.as<BlockRealizeNode>()) {
      const auto& block = realize->block;
      for (const auto& buffer : block->alloc_buffers) {
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
      pipeline_body = block->body;
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
    }

    const SeqStmtNode* pipeline_body_seq = pipeline_body.as<SeqStmtNode>();
    CHECK(pipeline_body_seq)
        << "ValueError: The body of the software pipeline should be SeqStmt, got "
        << pipeline_body->GetTypeKey();

    // Step 3: Blockize the components of the pipeline. Each child of the pipelined loop will be
    // converted into a block.
    PipelineInfo pipeline_info;
    Array<Block> original_order;  // pipeline body blocks in the original order

    auto f_add_child = [&](const Stmt& child) {
      original_order.push_back(MakeBlock(child, buffer_data_to_buffer_));
    };
    for (size_t i = 0; i < pipeline_body_seq->seq.size(); i++) {
      const auto* nested_block_realize = pipeline_body_seq->seq[i].as<BlockRealizeNode>();
      if (nested_block_realize && is_one(nested_block_realize->predicate) &&
          nested_block_realize->block->body->IsInstance<SeqStmtNode>()) {
        const Block& nested_pipeline_block = nested_block_realize->block;
        ICHECK(
            nested_pipeline_block->match_buffers.empty());  // match_buffer should have been lowered
        for (const auto& buffer : nested_pipeline_block->alloc_buffers) {
          pipeline_allocs.push_back(buffer);
          buffer_data_to_buffer_.Set(buffer->data, buffer);
        }
        const auto* nested_seq = nested_pipeline_block->body.as<SeqStmtNode>();
        for (size_t j = 0; j < nested_seq->seq.size(); j++) {
          f_add_child(nested_seq->seq[j]);
        }
      } else {
        f_add_child(pipeline_body_seq->seq[i]);
      }
    }

    auto pipeline_stages =
        Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_stage));
    auto pipeline_orders =
        Downcast<Array<Integer>>(op->annotations.at(attr::software_pipeline_order));
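    // Each annotation holds one integer per direct child of the pipelined loop. For example,
    // software_pipeline_stage = [0, 0, 1] with software_pipeline_order = [0, 2, 1] places the
    // first two children in stage 0 and the third in stage 1, and emits them within one pipelined
    // iteration in the order: first child, third child, second child.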
    CHECK_EQ(pipeline_stages.size(), original_order.size());
    CHECK_EQ(pipeline_orders.size(), original_order.size());

    std::unordered_set<int> pipeline_async_stages;
    if (auto annot = op->annotations.Get(attr::software_pipeline_async_stages)) {
      for (auto s : Downcast<Array<Integer>>(annot)) {
        pipeline_async_stages.insert(s->value);
      }
    }

    Map<String, ObjectRef> preserved_annotations;
    for (const auto& kv : op->annotations) {
      const String& key = kv.first;
      if (kv.first != attr::software_pipeline_stage && kv.first != attr::software_pipeline_order &&
          kv.first != attr::software_pipeline_async_stages) {
        preserved_annotations.Set(key, kv.second);
      }
    }

    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
      bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end();
      PipelineAnnotation stage_order{stage,
                                     /*order=*/static_cast<int>(pipeline_orders[i]->value),
                                     is_async};
      pipeline_info.emplace(original_order[i], stage_order);
    }

    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
    Stmt pipeline = PipelineRewriter::Rewrite(
        buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef<For>(op), pipeline_info,
        fragment_info_, preserved_annotations, merge_async_commit_queue_scope_);

    if (const auto* realize = op->body.as<BlockRealizeNode>()) {
      const auto& block = realize->block;
      for (const auto& buffer : block->alloc_buffers) {
        buffer_data_to_buffer_.erase(buffer->data);
      }
    }
    return pipeline;
  }

  /*!
   * \brief Add buffer allocations to a block and update the write region of the block.
   * \param n The block pointer to which the buffer allocations are added.
   * \param alloc_buffers The buffer allocations to be added.
   */
  void AddAllocBuffers(BlockNode* n, const Array<Buffer> alloc_buffers) {
    for (const Buffer& alloc_buffer : alloc_buffers) {
      n->alloc_buffers.push_back(alloc_buffer);
      Region region;
      region.reserve(alloc_buffer->shape.size());
      for (const PrimExpr& dim : alloc_buffer->shape) {
        region.push_back(Range::FromMinExtent(0, dim));
      }
      n->writes.push_back(BufferRegion(alloc_buffer, region));
    }
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    for (const auto& buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

    auto it = op->annotations.find(attr::double_buffer_scope);
    if (it != op->annotations.end()) {
      int buffer_index = Downcast<Integer>((*it).second).IntValue();
      CHECK(buffer_index >= 0 && static_cast<size_t>(buffer_index) < op->writes.size())
          << "ValueError: Index of the buffer exceeds the size of the write regions of the block. ("
          << buffer_index << " vs. " << op->writes.size() << ")";
      double_buffers.insert(op->writes[buffer_index]->buffer);
    }
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

    for (const auto& buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(buffer->data);
    }
    return std::move(block);
  }

  bool HasPipelineAnnotation(const ForNode* op) const {
    auto it1 = op->annotations.find(attr::software_pipeline_stage);
    auto it2 = op->annotations.find(attr::software_pipeline_order);
    bool has_stage = it1 != op->annotations.end();
    bool has_order = it2 != op->annotations.end();
    if (has_stage && has_order) {
      return true;
    }
    if (has_stage) {
      LOG(FATAL) << "ValueError: Order of the software pipeline is not defined.";
    }
    if (has_order) {
      LOG(FATAL) << "ValueError: Stage of the software pipeline is not defined.";
    }
    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  std::unordered_map<const VarNode*, FragmentInfo> fragment_info_;
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers;
  bool merge_async_commit_queue_scope_ = true;
};

}  // namespace software_pipeline

namespace transform {

/*!
 * \brief Transform annotated loops into pipelined ones that parallelize producers and consumers.
 * \return The IR transform pass.
 */
Pass InjectSoftwarePipeline() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto* fptr = f.CopyOnWrite();
    bool merge_async_commit_queue_scope =
        ctx->GetConfig<Bool>("tir.merge_async_commit_queue_scope", Bool(true)).value();
    fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope);
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.InjectSoftwarePipeline", {});
}

TVM_REGISTER_GLOBAL("tir.transform.InjectSoftwarePipeline").set_body_typed(InjectSoftwarePipeline);

}  // namespace transform

}  // namespace tir
}  // namespace tvm