#include <ir_utils.h>
#include <kernel_ir.h>
#include <lower2device.h>

#include <lower_double_buffer.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

unsigned int getDoubleBufferAxisPosition(const TensorView* tv) {
  // Double-buffering prefetches the next subregion of the tensor by
  // doubling the allocation. The subregion is defined by the axes
  // from the CA position to the inner-most position. There must be
  // at least one axis that is outside (left of) the CA position,
  // which defines the loop where prefetching is applied. Therefore,
  // the CA position must be larger than 0.
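  //
  // For example (illustrative): with tv = [I0, I1, I2] and a computeAt
  // position of 1, I0 is to the left of the CA position and becomes the
  // double-buffer (prefetch) loop, assuming it is not thread-parallelized
  // or a broadcast axis, while the allocation covering [I1, I2] is doubled
  // so the next I0 iteration can be loaded while the current one is used.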

  TORCH_INTERNAL_ASSERT(tv->getComputeAtPosition() > 0);

  // Unroll must not exist outside of double-buffer axis
  auto first_unroll_it = std::find_if(
      tv->domain()->domain().begin(),
      tv->domain()->domain().end(),
      [](const auto axis) {
        return axis->getParallelType() == ParallelType::Unroll;
      });

  const int first_unroll_pos =
      std::distance(tv->domain()->domain().begin(), first_unroll_it);

  const int unroll_or_ca_pos =
      std::min((int)tv->getComputeAtPosition(), first_unroll_pos);

  TORCH_INTERNAL_ASSERT(
      unroll_or_ca_pos > 0,
      "Invalid tensor to double-buffer. Valid double buffer axis not found due to Unroll. ",
      tv->toString());

  int valid_pos = -1;
  // Skip parallelized or broadcast axes
  for (int i = unroll_or_ca_pos - 1; i >= 0; --i) {
    auto pt = tv->axis(i)->getParallelType();
    if (!isParallelTypeThread(pt) && !tv->axis(i)->isBroadcast()) {
      valid_pos = i;
      break;
    }
  }

  TORCH_INTERNAL_ASSERT(
      valid_pos >= 0,
      "Invalid tensor to double-buffer. Valid double buffer axis not found. ",
      tv->toString());

  return valid_pos;
}

IterDomain* getDoubleBufferAxis(const TensorView* tv) {
  return tv->axis((int)getDoubleBufferAxisPosition(tv));
}

void validateDoubleBufferedTensor(const TensorView* tv) {
  auto double_buffer_pos = getDoubleBufferAxisPosition(tv);

  // Like vectorization, only UnaryOp::Set or LoadStoreOp with another
  // TensorView as its input is considered.
  auto def = tv->definition();
  TORCH_INTERNAL_ASSERT(
      (def->isA<UnaryOp>() &&
       def->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Set) ||
          // Load store op should generally support double buffering.
          def->isA<LoadStoreOp>(),
      "Invalid tensor to double-buffer. Only tensors defined by UnaryOp::Set or LoadStoreOp are supported: ",
      def->toString());

  TORCH_INTERNAL_ASSERT(
      def->input(0)->isA<TensorView>(),
      "Invalid tensor to double-buffer. Only tensors loaded from another TensorView are supported: ",
      def->toString());

  // Require the producer tensor to have been computed entirely for
  // the double-buffering loop. Otherwise, the producer itself would
  // also need to be double-buffered.
  auto producer = def->input(0)->as<TensorView>();
  TORCH_INTERNAL_ASSERT(
      producer->getComputeAtPosition() <= double_buffer_pos,
      "Invalid tensor to double-buffer. The computeAt position of the producer tensor must be moved left: ",
      producer->toString());

  // Not strictly necessary, but only gmem -> smem, gmem -> local, and
  // smem -> local are allowed.
  const auto p_mem_type = producer->getMemoryType();
  const auto c_mem_type = tv->getMemoryType();
  TORCH_INTERNAL_ASSERT(
      (p_mem_type == MemoryType::Global &&
       (c_mem_type == MemoryType::Shared || c_mem_type == MemoryType::Local)) ||
          (p_mem_type == MemoryType::Shared && c_mem_type == MemoryType::Local),
      "Invalid tensor to double-buffer: ",
      tv->toString(),
      ". Producer memory type: ",
      p_mem_type,
      ". Consumer memory type: ",
      c_mem_type);
}

namespace {

// Initial inspection of a fusion to find and validate double buffered tensors
class DoubleBufferFusionInspector : private IterVisitor {
 public:
  DoubleBufferFusionInspector(Fusion* fusion, DoubleBufferInfo& db_info)
      : db_info_(db_info) {
    traverse(fusion);
  }

 private:
  using IterVisitor::handle;

  void handle(TensorView* tv) final {
    if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
      return;
    }

    TORCH_INTERNAL_ASSERT(
        tv->definition(), "Fusion input shouldn't be double buffered.", tv);

    validateDoubleBufferedTensor(tv);

    auto db_axis = getDoubleBufferAxis(tv);

    db_info_.setDoubleBufferAxis(tv, db_axis);
  }

 private:
  DoubleBufferInfo& db_info_;
};

// The epilogue loop is only created when the producer of a double
// buffer tensor is on smem, in which case it would otherwise require
// an additional predicate to guard buffer overruns. When the producer
// is on gmem, that is not the case, so no epilogue loop is needed.
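//
// Roughly speaking (illustrative): the Main stage loop both loads and
// computes, so its last iteration prefetches data for a next iteration
// that does not exist. With an smem producer that read would overrun the
// producer buffer unless an extra predicate were added, so the Main loop
// is stopped one iteration early and the final compute-only iteration is
// peeled off into the Epilogue loop.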
bool requireEpilogue(const std::vector<Expr*>& exprs) {
  return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) {
    return expr->input(0)->as<TensorView>()->getMemoryType() ==
        MemoryType::Shared;
  });
}

// Replicates double buffer loops for Prologue, Main, and
// Epilogue. Prologue copies only the load expressions of double
// buffered tensors, whereas Epilogue copies everything except the
// loads. Main copies everything.
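//
// For a double-buffered (stage depth 2) load that requires an epilogue,
// the cloned loops conceptually look like this (illustrative sketch):
//
//   for (i = 0; i < 1; ++i)                // Prologue: loads only
//     A[...] = ...;
//   for (i = 0; i < extent - 1; ++i) {     // Main: loads and everything else
//     A[...] = ...;
//     ... = A[...];
//   }
//   for (i = extent - 1; i < extent; ++i)  // Epilogue: everything but loads
//     ... = A[...];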
class DoubleBufferLoopCloner : public kir::IrVisitor {
 public:
  static kir::ForLoop* clone(
      kir::ForLoop* double_buffer_loop,
      const std::vector<Expr*>& double_buffer_load_exprs,
      DoubleBufferLoopStage loop_type) {
    DoubleBufferLoopCloner cloner(
        double_buffer_loop, double_buffer_load_exprs, loop_type);
    cloner.clone();
    return cloner.cloned_top_level_loop_;
  }

 private:
  DoubleBufferLoopCloner(
      kir::ForLoop* double_buffer_loop,
      const std::vector<Expr*>& double_buffer_load_exprs,
      DoubleBufferLoopStage loop_type)
      : double_buffer_loop_(double_buffer_loop),
        double_buffer_load_exprs_(double_buffer_load_exprs),
        loop_type_(loop_type) {}

  using kir::IrVisitor::handle;

  void clone() {
    const auto gpu_lower = GpuLower::current();

    // Cloning the double buffer loop as follows:
    //
    // Prologue: 0 to (stage_depth - 1)
    // Main: 0 to (extent - 1), or to extent when no epilogue is needed
    // Epilogue: (extent - (stage_depth - 1)) to extent

    auto index = GpuLower::current()->caMap()->getIndexVariable(
        double_buffer_loop_->iter_domain(), loop_type_);
    auto start = double_buffer_loop_->start();
    auto stop = double_buffer_loop_->stop();
    auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor(
        double_buffer_loop_->iter_domain());

    if (loop_type_ == DoubleBufferLoopStage::Prolog) {
      TORCH_INTERNAL_ASSERT(start->isZeroInt());
      stop = SimplifyingIrBuilder::create<Int>(stage_depth - 1);
    } else if (
        loop_type_ == DoubleBufferLoopStage::Main &&
        requireEpilogue(double_buffer_load_exprs_)) {
      stop = IrBuilder::subExpr(
          double_buffer_loop_->stop(), gpu_lower->kernel()->oneVal());
    } else if (loop_type_ == DoubleBufferLoopStage::Epilog) {
      TORCH_INTERNAL_ASSERT(requireEpilogue(double_buffer_load_exprs_));
      start = IrBuilder::subExpr(
          double_buffer_loop_->stop(),
          SimplifyingIrBuilder::create<Int>(stage_depth - 1));
    }

    cloned_top_level_loop_ = IrBuilder::create<kir::ForLoop>(
        double_buffer_loop_->iter_domain(),
        index,
        start,
        stop,
        gpu_lower->kernel()->oneVal(),
        false,
        nullptr,
        double_buffer_loop_->isUnrollRequired(),
        loop_type_);

    handle(double_buffer_loop_);

    if (stage_depth > 2) {
      cloned_top_level_loop_->body().push_back(
          IrBuilder::create<kir::CpAsyncCommit>());
    }
  }

  void handle(kir::ForLoop* fl) final {
    kir::ForLoop* cloned_loop = fl == double_buffer_loop_
        ? cloned_top_level_loop_
        : IrBuilder::create<kir::ForLoop>(fl);

    cloned_scopes_.push_back(&cloned_loop->body());

    kir::IrVisitor::handle(fl);

    cloned_scopes_.pop_back();

    // Add the cloned loop into the parent loop body only when the
    // cloned loop contains expressions.
    if (!cloned_loop->body().empty() && !cloned_scopes_.empty()) {
      cloned_scopes_.back()->push_back(cloned_loop);
    }
  }

  void handle(kir::IfThenElse* ite) final {
    TORCH_INTERNAL_ASSERT(false, "No IfThenElse should exist yet");
  }

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      kir::IrVisitor::handle(expr);
      return;
    }

    TORCH_INTERNAL_ASSERT(!cloned_scopes_.empty());

    if (loop_type_ == DoubleBufferLoopStage::Main) {
      cloned_scopes_.back()->push_back(expr);
      return;
    }

    // In Prologue, only the load expressions are copied, whereas in
    // Epilogue, everything but the loads is copied. Note that there can
    // be multiple exprs defining double buffered TVs (e.g., buffer
    // initialization).

    auto out_tv = ir_utils::getTvOutput(expr);
    const auto is_double_buffer_load_expr = std::any_of(
        double_buffer_load_exprs_.begin(),
        double_buffer_load_exprs_.end(),
        [out_tv](const auto load_expr) {
          auto double_buffer_tv = ir_utils::getTvOutput(load_expr);
          TORCH_INTERNAL_ASSERT(double_buffer_tv != nullptr);
          return out_tv == double_buffer_tv;
        });
    if ((loop_type_ == DoubleBufferLoopStage::Prolog &&
         is_double_buffer_load_expr) ||
        (loop_type_ == DoubleBufferLoopStage::Epilog &&
         !is_double_buffer_load_expr)) {
      cloned_scopes_.back()->push_back(expr);
    }
  }

 private:
  kir::ForLoop* double_buffer_loop_ = nullptr;
  const std::vector<Expr*>& double_buffer_load_exprs_;
  const DoubleBufferLoopStage loop_type_;

  kir::ForLoop* cloned_top_level_loop_ = nullptr;
  std::deque<kir::Scope*> cloned_scopes_;
};

using InsertionInfo = std::unordered_map<kir::ForLoop*, std::vector<Expr*>>;

// Traverse lowered loop-nests and find all double buffer loops and
// associated load expressions.
class DoubleBufferLoopNestInspector : private kir::IrVisitor {
 public:
  static InsertionInfo run(const std::vector<Expr*>& exprs) {
    DoubleBufferLoopNestInspector inspector(exprs);
    return inspector.insertion_info_;
  }

 private:
  DoubleBufferLoopNestInspector(const std::vector<Expr*>& exprs) {
    handle(exprs);
  }

  using kir::IrVisitor::handle;

  // Collect double buffer related information on an expr
  // that is a memory load, i.e., a LoadStore or a Set.
  void handlePossibleLoadExpr(Expr* expr) {
    const auto gpu_lower = GpuLower::current();

    auto out_tv = ir_utils::getTvOutput(expr);

    if (out_tv == nullptr) {
      return;
    }

    // Ignore init loop
    if (!(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered()) ||
        !expr->input(0)->isA<TensorView>()) {
      return;
    }

    auto double_buffer_loop =
        gpu_lower->doubleBufferInfo().getDoubleBufferLoop(out_tv, for_loops_);

    TORCH_INTERNAL_ASSERT(
        double_buffer_loop != nullptr,
        "No double buffer loop found for a double buffered tensor: ",
        out_tv->toString());

    validateDoubleBufferLoop(double_buffer_loop);

    insertion_info_[double_buffer_loop].push_back(expr);
  }

  void handle(UnaryOp* uop) final {
    handlePossibleLoadExpr(uop);
  }

  void handle(LoadStoreOp* ldst) final {
    handlePossibleLoadExpr(ldst);
  }

  static void validateDoubleBufferLoop(kir::ForLoop* loop) {
    TORCH_INTERNAL_ASSERT(
        loop->start()->isZeroInt(), "Unsupported loop: ", loop->toString());
    TORCH_INTERNAL_ASSERT(
        loop->step()->isOneInt(), "Unsupported loop: ", loop->toString());
    TORCH_INTERNAL_ASSERT(
        !loop->vectorize(),
        "Vectorized loop should not be the allocation loop for double-buffered tensor: ",
        loop->toString());
    TORCH_INTERNAL_ASSERT(
        !loop->vectorize_shift(),
        "Vectorize shift loop should not be the allocation loop for double-buffered tensor: ",
        loop->toString());
  }

  InsertionInfo insertion_info_;
};

// Apply double buffering transformations
class DoubleBufferInserter : private kir::ExprMutator {
 public:
  // When there exist multiple double buffer loops, apply
  // transformations to inner-most loops first. A single ExprMutator
  // pass can only process one loop.
  static std::vector<Expr*> run(
      const std::vector<Expr*>& exprs,
      InsertionInfo insertion_info) {
    auto inserted_exprs = exprs;
    while (!insertion_info.empty()) {
      DoubleBufferInserter inserter(inserted_exprs, insertion_info);
      inserted_exprs = inserter.exprs_;
    }
    return inserted_exprs;
  }

 private:
  DoubleBufferInserter(
      const std::vector<Expr*>& exprs,
      InsertionInfo& insertion_info)
      : insertion_info_(insertion_info) {
    auto num_double_buffer_loops = insertion_info.size();
    traverseAndInsert(exprs);
    TORCH_INTERNAL_ASSERT(processed_loop_ != nullptr);
    TORCH_INTERNAL_ASSERT(insertion_info.size() == num_double_buffer_loops - 1);
  }

  using kir::ExprMutator::handle;

  void handle(kir::ForLoop* loop) final {
    kir::ExprMutator::handle(loop);

    // If another loop has already been handled, no further loop
    // should be processed in the same pass.
    if (processed_loop_ != nullptr) {
      return;
    }

    auto it = insertion_info_.find(loop);
    if (it == insertion_info_.end()) {
      return;
    }

    insert(loop, it->second);
    processed_loop_ = loop;
    insertion_info_.erase(loop);
  }

  void insert(
      kir::ForLoop* double_buffer_loop,
      const std::vector<Expr*>& loads) {
    auto prologue_loop = DoubleBufferLoopCloner::clone(
        double_buffer_loop, loads, DoubleBufferLoopStage::Prolog);
    registerInsertBefore(double_buffer_loop, prologue_loop);

    auto write_to_smem =
        std::any_of(loads.begin(), loads.end(), [](const Expr* expr) {
          return expr->output(0)->as<TensorView>()->getMemoryType() ==
              MemoryType::Shared;
        });

    // RAW sync is not inserted for double buffered tensors. The only
    // exception is the prologue load.
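    //
    // Conceptually, the code inserted before the main loop looks like
    // this (illustrative sketch):
    //
    //   prologue loop { async gmem->smem loads }
    //   kir::CpAsyncWait(stage_depth - 2);  // wait for prologue cp.async loads
    //   kir::BlockSync();                   // initial RAW sync, if needed
    //   main loop { ... }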
    bool insert_cpasync_wait = false;
    if (write_to_smem) {
      // Here the initial sync before entering the double buffer loop is
      // inserted.

      // If any of the double buffered tensors in this double buffer
      // loop is loaded by async copy, we want to wait for the gmem
      // loads to finish before synchronizing the block.
      if (std::any_of(loads.begin(), loads.end(), ir_utils::isCpAsyncOp)) {
        auto stage_depth =
            GpuLower::current()->doubleBufferInfo().getStageDepthFor(
                double_buffer_loop->iter_domain());
        auto cp_async_wait =
            IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);
        registerInsertBefore(double_buffer_loop, cp_async_wait);
        insert_cpasync_wait = true;
      }

      // Insert the initial block sync before entering the main loop.
      if (std::any_of(loads.begin(), loads.end(), [](Expr* expr) {
            return GpuLower::current()
                ->syncMap()
                .needsRawSync(ir_utils::getTvOutput(expr))
                .hasTID();
          })) {
        // If any of the double buffered loads requires a sync, as indicated
        // by the sync info map, insert the sync before entering the double
        // buffer loop.
        // TODO:
        // Double buffering in gmem is currently not supported, but it is
        // not a short- to mid-term priority.
        auto sync = IrBuilder::create<kir::BlockSync>(false);
        registerInsertBefore(double_buffer_loop, sync);
      }
    }

    auto main_loop = DoubleBufferLoopCloner::clone(
        double_buffer_loop, loads, DoubleBufferLoopStage::Main);

    registerReplace(double_buffer_loop, main_loop);

    // Insert the wait instruction in this pass instead
    // of relying on the WAR sync pass to do it.
    // The WAR sync pass today would insert the wait
    // exactly where we need it, but the purpose of this wait
    // insertion isn't exactly WAR protection.
    //
    // TODO: [Double Buffer Sync]
    // We might eventually want to move the block sync inserted
    // by the WAR pass here as well, since this sync insertion is kind
    // of both WAR and RAW (or neither RAW nor WAR, depending
    // on how we look at it).
    // E.g., in the case when an intermediate
    // tensor is double buffered.
    //
    //  __block_sync(); // This is the initial sync
    //  For i in ... // Double buffer loop
    //    A[i%2] = ...;
    //    ... = A[1-i%2];
    //    __block_sync(); // sync within loop
    //  ...
    // The "sync within loop" can be placed anywhere in the
    // double buffer loop, while in the case of RAW and WAR
    // there'd be extra insertion point restrictions.
    // We are currently not actively exploring opportunities
    // with this property of "double buffer sync", so this
    // is more conceptual at the moment, aka low priority.
    if (insert_cpasync_wait) {
      insertCpAsyncWaitInMainLoop(main_loop);
    }

    if (requireEpilogue(loads)) {
      auto epilogue_loop = DoubleBufferLoopCloner::clone(
          double_buffer_loop, loads, DoubleBufferLoopStage::Epilog);
      registerInsertAfter(double_buffer_loop, epilogue_loop);
    }
  }

  // Simple conservative rule for inserting the async copy wait
  // primitive in the double buffer main loop:
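  //
  // Place the wait just before the block sync at the end of the main
  // loop if the WAR sync pass inserted one; otherwise place it at the
  // very end of the loop body. For example (illustrative sketch):
  //
  //   for (i = ...) {               // main loop
  //     A[...] = ...;               // cp.async load
  //     ... = A[...];
  //     kir::CpAsyncWait(stage_depth - 2);
  //     kir::BlockSync();           // WAR sync, if present
  //   }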
  void insertCpAsyncWaitInMainLoop(kir::ForLoop* main_loop) {
    TORCH_INTERNAL_ASSERT(
        !main_loop->body().empty(),
        "Double buffer sync insertion: empty main loop.");
    // Note: This pass explicitly assumes that WAR sync has been
    // inserted so would need to be updated if we re-order the
    // passes. Cleanups suggested in [Double Buffer Sync]
    // would resolve this dependency on pass ordering.
    auto end_of_loop_expr = main_loop->body().exprs().back();
    auto stage_depth = GpuLower::current()->doubleBufferInfo().getStageDepthFor(
        main_loop->iter_domain());
    auto cp_async_wait = IrBuilder::create<kir::CpAsyncWait>(stage_depth - 2);

    // Check if a sync has been inserted by WAR sync pass.
    auto block_sync_it = std::find_if(
        main_loop->body().exprs().rbegin(),
        main_loop->body().exprs().rend(),
        [](const Expr* expr) { return expr->isA<kir::BlockSync>(); });
    if (block_sync_it == main_loop->body().exprs().rend()) {
      // If there's no sync, i.e., no tensor needs cross-thread
      // communication, we still need a wait, but it can be placed
      // anywhere in the loop. We arbitrarily choose to place it at
      // the end.
      main_loop->body().insert_after(end_of_loop_expr, cp_async_wait);
    } else {
      // If a sync has been inserted, wait needs to be placed
      // before the sync.
      main_loop->body().insert_before(*block_sync_it, cp_async_wait);
    }
  }

 private:
  InsertionInfo& insertion_info_;
  kir::ForLoop* processed_loop_ = nullptr;
};

} // namespace

void DoubleBufferInfo::build(Fusion* fusion) {
  DoubleBufferFusionInspector inspector(fusion, *this);

  // Build double buffered loop ids
  for (auto& info : map_) {
    auto double_buffer_axis = info.second.double_buffer_axis;
    // Keep track of which loop disjoint set has been
    // double buffered. In index allocation, one index
    // variable would need to be allocated in each
    // double buffer stage.
    concrete_double_buffered_loop_id_.insert(
        GpuLower::current()->caMap()->getConcreteMappedID(
            double_buffer_axis, IdMappingMode::LOOP));
  }
}

bool DoubleBufferInfo::isDoubleBufferedIterDomain(IterDomain* id) {
  auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::LOOP);
  return concrete_double_buffered_loop_id_.count(concrete_loop_id);
}

DoubleBufferInfo::TvInfo& DoubleBufferInfo::getTvInfo(const TensorView* tv) {
  TORCH_INTERNAL_ASSERT(
      tv->isDoubleBuffered() || tv->isCircularBuffered(),
      "Not a double-buffered tensor: ",
      tv->toString());
  return map_[tv];
}

void DoubleBufferInfo::setDoubleBufferAxis(
    const TensorView* tv,
    IterDomain* axis) {
  getTvInfo(tv).double_buffer_axis = axis;

  // Also validate the stage consistency with the CA map.
  unsigned int stage_depth = 0;
  if (tv->isCircularBuffered()) {
    stage_depth = tv->circularBufferDepth();
  } else {
    // Double buffering is essentially
    // circular buffering with depth 2.
    stage_depth = 2;
  }

  // Set and validate the new stage depth.
  setStageDepth(axis, stage_depth);
}

void DoubleBufferInfo::setStageDepth(IterDomain* id, unsigned int stage_depth) {
  auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::LOOP);

  auto maybe_existing_depth_it = stage_depth_.find(concrete_loop_id);
  if (maybe_existing_depth_it == stage_depth_.end()) {
    stage_depth_[concrete_loop_id] = stage_depth;
  } else {
    TORCH_INTERNAL_ASSERT(
        stage_depth == maybe_existing_depth_it->second,
        "Unsupported multiple depth pipelining, was set to ",
        maybe_existing_depth_it->second,
        " by ",
        maybe_existing_depth_it->first->toString(),
        " and then set to ",
        stage_depth,
        " by ",
        concrete_loop_id->toString());
  }
}

IterDomain* DoubleBufferInfo::getDoubleBufferAxis(const TensorView* tv) {
  if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
    return nullptr;
  }

  return getTvInfo(tv).double_buffer_axis;
}

unsigned int DoubleBufferInfo::getStageDepthFor(
    IterDomain* double_buffer_axis) {
  auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
      double_buffer_axis, IdMappingMode::LOOP);

  auto maybe_depth_it = stage_depth_.find(concrete_id);

  TORCH_INTERNAL_ASSERT(
      maybe_depth_it != stage_depth_.end(), "Stage depth not found");

  return maybe_depth_it->second;
}

kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop(
    IterDomain* axis,
    const std::vector<kir::ForLoop*>& loops,
    bool ignore_prologue) {
  auto loop_it = std::find_if(loops.begin(), loops.end(), [&](const auto loop) {
    return GpuLower::current()->caMap()->areMapped(
               loop->iter_domain(), axis, IdMappingMode::EXACT) &&
        (!ignore_prologue ||
         loop->doubleBufferLoopStage() != DoubleBufferLoopStage::Prolog);
  });

  if (loop_it != loops.end()) {
    return *loop_it;
  } else {
    return nullptr;
  }
}

kir::ForLoop* DoubleBufferInfo::getDoubleBufferLoop(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& loops,
    bool ignore_prologue) {
  auto axis = getDoubleBufferAxis(tv);

  if (axis == nullptr) {
    return nullptr;
  }

  return getDoubleBufferLoop(axis, loops, ignore_prologue);
}

void DoubleBufferInfo::setOriginalAllocSize(
    const TensorView* tv,
    Val* original_alloc_size) {
  getTvInfo(tv).original_alloc_size = original_alloc_size;
}

Val* DoubleBufferInfo::getOriginalAllocSize(const TensorView* tv) {
  if (!(tv->isDoubleBuffered() || tv->isCircularBuffered())) {
    return nullptr;
  }

  return getTvInfo(tv).original_alloc_size;
}

std::vector<Expr*> DoubleBufferPass::run(const std::vector<Expr*>& exprs) {
  auto insertion_info = DoubleBufferLoopNestInspector::run(exprs);
  return DoubleBufferInserter::run(exprs, insertion_info);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch