#include <lower_alias_memory.h>

#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir.h>
#include <lower2device.h>
#include <lower_utils.h>

#include <sstream>
#include <unordered_map>
#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {
// Alias used for std::transform
IterDomain* exactConcreteId(IterDomain* id) {
  return GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::EXACT);
}

//! Checks whether the current loop nest is realizing a serial
//! broadcast, in which case each index of the producer buffer can be
//! visited multiple times and the aggressive in-place sharing is not valid.
bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) {
  //! Note: see issue #1785:
  //! serial broadcast resolution doesn't only happen to
  //! immediate outputs of broadcast ops. We can also have, for example:
  //!  T1[I,B] = broadcast(T0[I])
  //!  T3[I,I] = T1[I,B] + T2[I,I]
  //!  T4[I,I] = T3[I,I]
  //! which generates the following loop:
  //! alloc T0[4]
  //! For i in 0..3
  //!   T0[...] =
  //!
  //! For j in 0..X:
  //!   alloc T3[4]
  //!   for k in 0..3:
  //!     alloc T1[1]
  //!     T1[0] = T0[k] // <- This is actually a broadcast resolution
  //!     T3[k] = T1[0] + T2[...]
  //!   T4[...] = T3[...]
  //!
  //! In this case we are actually visiting each pixel of T0 in each iteration
  //! of the j loop, while T1 was the broadcasted tensor causing this re-use.
  //!
  //! The current version of the check covers this scenario by checking the
  //! root ids of the consumer's concrete loop ids. Any time a local tensor
  //! like T0 appears in a re-use scenario like the above, we should see a
  //! serial loop id that was derived from some root id that doesn't
  //! concretely map to T0's domain.

  // Serial concrete loop id's that cover consumer's iter domain.
  std::vector<Val*> consumer_serial_loop_concrete_ids;

  for (auto consumer_leaf_id : consumer->domain()->domain()) {
    auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID(
        consumer_leaf_id, IdMappingMode::LOOP);

    // Check for any serial loop id with non-trivial extent
    if (!concrete_loop_id->isThread() &&
        !concrete_loop_id->extent()->isOneInt()) {
      consumer_serial_loop_concrete_ids.push_back(concrete_loop_id);
    }
  }

  // Collect the root id's that the serial loop iter domains
  //  are transformed from.
  auto serial_loop_roots = InputsOf::outputs(
      FusionGuard::getCurFusion(), consumer_serial_loop_concrete_ids);

  // Collect exact concrete id's in producer's root domain
  std::unordered_set<IterDomain*> producer_exact_concrete_root_ids;
  auto producer_root =
      TensorDomain::noReductions(producer->getMaybeRFactorDomain());
  std::transform(
      producer_root.begin(),
      producer_root.end(),
      std::inserter(
          producer_exact_concrete_root_ids,
          producer_exact_concrete_root_ids.begin()),
      exactConcreteId);

  // Check if the serial loop roots index any exact root id's that are not
  //  within the set of the producer's exact root id's. Such id's imply that
  //  the same producer pixel is accessed in multiple iterations of the
  //  materialized serial loop.
  for (auto serial_loop_root :
       ir_utils::filterByType<IterDomain>(serial_loop_roots)) {
    if (!producer_exact_concrete_root_ids.count(
            GpuLower::current()->caMap()->getConcreteMappedID(
                serial_loop_root, IdMappingMode::EXACT))) {
      return true;
    }
  }

  return false;
}
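
// A minimal usage sketch (hypothetical tensors t0/t3), assuming the lowering
// context is already set up as it is when this pass runs:
//   bool resolves = isSerialBroadcastResolution(t0, t3);
//   // If true, t0's inner live interval is no longer trusted for in-place
//   // sharing (see BufferUseDefInfo::collectLivenessInfo below).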

//! Get string representation of Allocate size for symbolic comparison
//!
//! TODO: Some expr simplifications could also be helpful
class SymbolicSizePrinter : private OptOutConstDispatch {
 public:
  static std::string printSize(const kir::Allocate* allocate) {
    SymbolicSizePrinter printer;
    printer.handle(allocate->size());
    return printer.os_.str();
  }

 private:
  using OptOutConstDispatch::handle;

  void handle(const Int* node) final {
    if (auto def = node->definition()) {
      OptOutConstDispatch::handle(def);
    } else if (node->isConst()) {
      os_ << *node->value();
    } else {
      os_ << "ki" << node->name();
    }
  }

  void handle(const NamedScalar* named_scalar) final {
    os_ << "@" << named_scalar->name();
  }

  void handle(const UnaryOp* unary_op) final {
    os_ << unary_op->getUnaryOpType() << "(";
    // Print the input of the unary op (dispatching on the op itself would
    //  recurse back into this handler).
    OptOutConstDispatch::handle(unary_op->in());
    os_ << ")";
  }

  void handle(const BinaryOp* binary_op) final {
    os_ << binary_op->getBinaryOpType() << "(";
    OptOutConstDispatch::handle(binary_op->lhs());
    os_ << ",";
    OptOutConstDispatch::handle(binary_op->rhs());
    os_ << ")";
  }

 private:
  std::stringstream os_;
};
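
// As an illustration, for a hypothetical allocation whose size expression is
// 4 * blockDim.x, printSize() would produce a string such as
// "mul(4,@blockDim.x)" (the exact op spelling comes from how BinaryOpType is
// streamed). Two allocations are treated as the same size only when these
// strings compare equal, so the comparison is purely syntactic, hence the
// TODO above about expression simplification.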

class BufferUseDefInfo;
//! A debug printer internal to this pass to support
//!  future expansion and inline annotation of pass info.
class BufferReuseDebugPrinter {
  enum class DebugLineType { EXPR, START_BLOCK, END_BLOCK };

  struct ExprInfo {
    int lineno = 0;
    DebugLineType line_type = DebugLineType::EXPR;
  };

  using DebugEntry = std::pair<ExprInfo, Expr*>;
  using DebugEntryPtr = std::unique_ptr<DebugEntry>;

 public:
  BufferReuseDebugPrinter() : ir_printer_(os_) {}

  std::string dumpDebugInfo() {
    os_.clear();
    for (auto& debug_entry : debug_info_) {
      switch (debug_entry->first.line_type) {
        case DebugLineType::START_BLOCK:
          startBlock();
          break;
        case DebugLineType::END_BLOCK:
          endBlock();
          break;
        case DebugLineType::EXPR:
          os_ << debug_entry->first.lineno;
          handle(debug_entry->second);
          break;
        default:
          TORCH_INTERNAL_ASSERT(false, "unreachable");
      }
    }
    os_ << "\n\n";
    return os_.str();
  }

 private:
  friend class BufferUseDefInfo;

  void pushBack(int lineno, Expr* expr) {
    makeExprEntry(lineno, expr);
  }

  void pushScope() {
    makeScopeEntry(DebugLineType::START_BLOCK);
  }

  void popScope() {
    makeScopeEntry(DebugLineType::END_BLOCK);
  }

  void makeExprEntry(int lineno, Expr* expr) {
    auto debug_entry_ptr = std::make_unique<DebugEntry>();
    debug_entry_ptr->first.lineno = lineno;
    debug_entry_ptr->second = expr;
    debug_info_.emplace_back(std::move(debug_entry_ptr));
  }

  void makeScopeEntry(DebugLineType line_type) {
    TORCH_INTERNAL_ASSERT(
        line_type == DebugLineType::END_BLOCK ||
        line_type == DebugLineType::START_BLOCK);
    auto debug_entry_ptr = std::make_unique<DebugEntry>();
    debug_entry_ptr->first.line_type = line_type;
    debug_entry_ptr->second = nullptr;
    debug_info_.emplace_back(std::move(debug_entry_ptr));
  }

  void handle(const Expr* node) {
    if (auto for_loop = dynamic_cast<const kir::ForLoop*>(node)) {
      handle(for_loop);
    } else if (auto ite = dynamic_cast<const kir::IfThenElse*>(node)) {
      handle(ite);
    } else {
      indent();
      ir_printer_.handle(node);
    }
    if (auto alloc = dynamic_cast<const kir::Allocate*>(node)) {
      printAllocInfo(alloc);
    }
  }

  void handle(const kir::ForLoop* node) {
    indent();
    os_ << "FOR ";
    ir_printer_.handle(node->index());
    os_ << " in ";
    ir_printer_.handle(node->iter_domain());
    os_ << ":\n";
  }

  void handle(const kir::IfThenElse* node) {
    // This pass doesn't yet need to handle IfThenElse, but this handler could
    //  be filled in if the printer is reused by other passes or if more
    //  complex ite patterns show up.
    TORCH_INTERNAL_ASSERT(false, "unsupported");
  }

  void printAllocInfo(const kir::Allocate* alloc);

  std::stringstream& indent() {
    for (const auto i : c10::irange(indent_level_)) {
      (void)i; // Suppress unused variable warning
      os_ << " ";
    }
    return os_;
  }

  void startBlock() {
    indent_level_++;
  }

  void endBlock() {
    indent_level_--;
  }

 private:
  std::stringstream os_;
  IrPrinter ir_printer_;
  int indent_level_ = 0;

  std::vector<DebugEntryPtr> debug_info_;
  BufferUseDefInfo* buffer_info_ = nullptr;
};

//! Utility class for modeling the liveness interval.
//! The first write and last read
//!  are based on the position in the linear order within
//!  the Kernel IR.
//!  The interval is semi-open,
//!     i.e. [First_Write, Last_Read)
//!  So the buffer is NOT available for re-use at exactly the First_Write
//!   position while it IS available at Last_Read.
class BufferLiveInterval {
 public:
  // Simple detection of intersection of two intervals
  bool intersect(BufferLiveInterval* other) {
    if (first_write_pos_ <= other->first_write_pos_) {
      return other->first_write_pos_ < last_read_pos_;
    } else {
      return first_write_pos_ < other->last_read_pos_;
    }
  }

  void markWrite(int pos) {
    if (first_write_pos_ == -1) {
      first_write_pos_ = pos;
    }
  }

  void markRead(int pos) {
    last_read_pos_ = pos;
    TORCH_INTERNAL_ASSERT(
        first_write_pos_ > 0,
        "lower_alias_memory: a read seen before any write");
    TORCH_INTERNAL_ASSERT(
        pos > first_write_pos_,
        "lower_alias_memory: marking a read before write");
    all_read_pos_.push_back(pos);
  }

  const auto& allReads() {
    return all_read_pos_;
  }

  auto firstWrite() const {
    return first_write_pos_;
  }

  auto lastRead() const {
    return last_read_pos_;
  }

  std::string toString() {
    std::stringstream ss;
    ss << "[ " << first_write_pos_ << " , " << last_read_pos_ << " )";
    return ss.str();
  }

 private:
  int first_write_pos_ = -1;
  int last_read_pos_ = -1;
  std::vector<int> all_read_pos_;
};
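
// Worked example (illustrative positions only): a buffer first written at
// position 3 and last read at position 7 has the interval [3, 7). It does
// not intersect a buffer living in [7, 10), so the two could share storage,
// but it does intersect one living in [5, 9).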

using BufferLiveIntervalPtrList = std::vector<BufferLiveInterval*>;

//! Thin struct to keep track of loops. The actual loop body is
//!  considered live in [start_pos, end_pos)
struct ScopeInfo {
  int start_pos = -1;
  int end_pos = -1;

  // nullptr means it's global scope
  kir::ForLoop* loop = nullptr;
};

using ScopeInfoOwningPtr = std::unique_ptr<ScopeInfo>;
using ScopeInfoOwningPtrList = std::vector<ScopeInfoOwningPtr>;

//! Utility class to record the reads and writes of each
//! allocated buffer.
//!
//! Note:
//!  this simplified interval analysis only works on pointwise ops,
//!  reductions, and broadcasts, with no non-trivial IfThenElse and no
//!  non-trivial re-computation.
//!
//! Will probably at some point need dataflow and index analysis to precisely
//!  handle loop-carried dependencies.
struct AllocationUseDefInfo {
  kir::Allocate* alloc_expr = nullptr;
  kir::Allocate* alias_to = nullptr;
  bool is_inner_alias = false;
  bool should_try_alias = true;
  MemoryType mem_type = MemoryType::Local;
  DataType data_type = DataType::Float;
  std::string size_expr;
  ScopeInfo* loop_info = nullptr;
  bool can_use_inner_alias = true;
  int alloc_pos = -1;
  std::unique_ptr<std::vector<AllocationUseDefInfo*>> inner_alias_list_ =
      nullptr;
  std::unique_ptr<BufferLiveInterval> inner_live_interval = nullptr;
  std::unique_ptr<BufferLiveIntervalPtrList> inner_subscribed_intevals =
      nullptr;
  std::unique_ptr<BufferLiveInterval> outer_live_interval = nullptr;
  std::unique_ptr<BufferLiveIntervalPtrList> outer_subscribed_intevals =
      nullptr;
};

using AllocationInfoOwningPtr = std::unique_ptr<AllocationUseDefInfo>;
using AllocationInfoOwningList = std::vector<AllocationInfoOwningPtr>;
using AllocationInfoPtr = AllocationUseDefInfo*;
using AllocationInfoList = std::vector<AllocationInfoPtr>;

//! Analysis pass to collect the liveness info of local and shared buffers:
//! The liveness info is illustrated as follows:
//!
//! For Idx0 ...
//!   Alloc(T1, register)
//!   Alloc(T2, register)
//!   Alloc(T3, register)
//!
//!   For Idx1 ...     <---------- Outer Live Interval of T1 begin
//!     For Idx2 ...
//!       T1 = ...            <--  Inner Live Interval of T1 begin
//!       T2 = ...
//!       T3 = T1 + ...    <-- Inner Live Interval of T1 end
//!       T5 = T3 + ...
//!     EndFor Idx2
//!   EndFor Idx1 <-------  Outer Live Interval of T1 end
//!
//!   Alloc(T4, register)
//!   For Idx3 ...
//!     T4 = ...
//!   EndFor Idx3
//! EndFor Idx0
//!
//! Each buffer is associated with an `inner_live_interval` and an
//! `outer_live_interval`.
//! The inner interval marks the exprs that are the first write and last read
//!  of the buffer.
//! The outer interval marks the beginning of the loop of the first write and
//!  the end of the loop of the last read, both at the same loop level as the
//!  buffer allocation.
class BufferUseDefInfo {
 public:
  // Alias local memory if it exceeds this threshold
  static constexpr long kRegisterSizeThreshold = 1;

  BufferUseDefInfo(
      const std::vector<Expr*>& exprs,
      BufferReuseDebugPrinter* debug_printer = nullptr)
      : debug_printer_(debug_printer) {
    if (debug_printer) {
      debug_printer->buffer_info_ = this;
    }
    collectScopeInfo(exprs);
    collectScopeUseDefInfo(exprs);
  }

  //! Returns live interval info of buffer if previously
  //!  computed.
  c10::optional<AllocationInfoPtr> getMaybeReuseInfoFor(
      kir::Allocate* allocate) const {
    auto alloc_it = map_allocate_to_info_.find(allocate);
    if (alloc_it == map_allocate_to_info_.end()) {
      return c10::nullopt;
    }
    auto alloc = alloc_it->second;
    return alloc;
  }

  //! Realize alias of two buffers through inner alias analysis and
  //!  keep track of the re-use.
  void useInnerAlias(AllocationInfoPtr from, AllocationInfoPtr to) {
    to->inner_alias_list_->push_back(from);
    to->inner_subscribed_intevals->push_back(from->inner_live_interval.get());
    setAlias(from, to);
    from->is_inner_alias = true;
  }

  //! Realize alias of two buffers through outer alias analysis and
  //!  keep track of the re-use.
  void useOuterAlias(AllocationInfoPtr from, AllocationInfoPtr to) {
    to->outer_subscribed_intevals->push_back(from->outer_live_interval.get());
    setAlias(from, to);
  }

  //! To run before performing in-place sharing analysis.
  //!  Initializes each allocation's subscribed intervals with its own
  //!  inner live interval.
  void prepareInnerSharingAnalysis() {
    for (auto it : map_allocate_to_info_) {
      auto alloc_info = it.second;
      // Initially the only subscribed interval of each allocation is its
      //  own inner live interval.
      alloc_info->inner_subscribed_intevals->push_back(
          alloc_info->inner_live_interval.get());
    }
  }

  //! To run before performing outer-interval-based sharing analysis.
  //!  Initializes each allocation's subscribed intervals with its own
  //!  outer live interval and copies over the inner sharing information.
  void prepareOuterSharingAnalysis() {
    for (auto it : map_allocate_to_info_) {
      auto alloc_info = it.second;
      // Update only if this buffer isn't itself an alias of another buffer.
      if (!alias_map_.count(alloc_info)) {
        alloc_info->outer_subscribed_intevals->push_back(
            alloc_info->outer_live_interval.get());
        for (auto inner_alias : *(alloc_info->inner_alias_list_)) {
          alloc_info->outer_subscribed_intevals->push_back(
              inner_alias->outer_live_interval.get());
        }
      }
    }
  }
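
  // For example (illustrative): if T2 was already inner-aliased into T1's
  // buffer, T1's subscribed interval list holds both T1's and T2's live
  // intervals, and a later candidate may only reuse that buffer if its own
  // interval intersects none of the subscribed ones (see
  // tryReuseOtherAllocate in AllocateReuseModifier below).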

 private:
  void handle(Expr* expr) {
    current_pos_++;
    if (debug_printer_) {
      debug_printer_->pushBack(current_pos_, expr);
    }
    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
      handle(alloc);
    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      handle(for_loop);
    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
      handle(ite);
    } else {
      collectLivenessInfo(expr);
    }
  }

  void handleScope(const std::vector<Expr*>& exprs) {
    if (debug_printer_) {
      debug_printer_->pushScope();
    }
    for (auto expr : exprs) {
      handle(expr);
    }
    if (debug_printer_) {
      debug_printer_->popScope();
    }
  }

  void handle(kir::ForLoop* for_loop) {
    auto loop_info = map_loop_pos_to_loop_info_.at(current_pos_);
    current_stack_.push_back(loop_info);
    handleScope(for_loop->body().exprs());
    current_stack_.pop_back();
  }

  void handle(kir::IfThenElse* ite) {
    TORCH_INTERNAL_ASSERT(
        false, "lower_alias_memory: no support for IfThenElse at this phase.");
  }

  // Generate allocation info for an allocation, after applying some
  //  pre-filtering conditions.
  void handle(kir::Allocate* alloc) {
    if (alloc->alias()) {
      // We shouldn't really see a case like this in general, but
      //  some Fusion outputs could have been aliased to inputs.
      // It should be safe to ignore these in the use-def analysis.
      return;
    }

    auto tv = dynamic_cast<TensorView*>(alloc->buffer());
    if (!tv) {
      return;
    }

    // Collect the allocate info data

    // Collect memory type, skip global buffers
    auto mem_type = tv->getMemoryType();
    if (mem_type != MemoryType::Local && mem_type != MemoryType::Shared) {
      return;
    }

    // Skip register allocations at or below the size threshold
    bool should_try_alias = true;
    if (mem_type == MemoryType::Local) {
      const auto register_size = expr_evaluator_.evaluate(alloc->size());
      if (!register_size.has_value()) {
        TORCH_WARN_ONCE(
            "Lower_alias_memory : dynamic sized register allocation");
        return;
      }
      if (register_size->as<int64_t>() <= kRegisterSizeThreshold) {
        should_try_alias = false;
      }
    }

    auto data_type = tv->dtype();
    auto size_print = SymbolicSizePrinter::printSize(alloc);

    // Make sure we don't have conflicting information on record
    TORCH_INTERNAL_ASSERT(!map_allocate_to_info_.count(alloc));
    TORCH_INTERNAL_ASSERT(!map_tv_to_allocations_.count(tv->name()));

    // Make the AllocationUseDefInfo:
    auto alloc_info = makeUseDefInfo();
    alloc_info->alloc_expr = alloc;
    alloc_info->mem_type = mem_type;
    alloc_info->data_type = data_type;
    alloc_info->size_expr = size_print;
    alloc_info->loop_info = current_stack_.back();
    alloc_info->should_try_alias = should_try_alias;

    // Record shortcuts
    map_allocate_to_info_[alloc] = alloc_info;
    map_tv_to_allocations_[tv->name()] = alloc_info;
  }

  void collectScopeUseDefInfo(const std::vector<Expr*>& exprs) {
    // Reset position pointer
    resetExprCounter();
    TORCH_INTERNAL_ASSERT(global_scope_info_ != nullptr);
    current_stack_.push_back(global_scope_info_);
    handleScope(exprs);
  }

  void collectScopeInfo(const std::vector<Expr*>& exprs) {
    // Reset position pointer
    resetExprCounter();
    collectScopeInfoWithinLoop(exprs, nullptr);
  }

  void collectScopeInfoWithinLoop(
      const std::vector<Expr*>& exprs,
      kir::ForLoop* current_loop) {
    auto loop_info = makeScopeInfo(current_loop);
    for (auto expr : exprs) {
      current_pos_++;
      if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
        collectScopeInfoWithinLoop(for_loop->body().exprs(), for_loop);
      }
    }
    loop_info->end_pos = current_pos_ + 1;
  }

  void resetExprCounter() {
    current_pos_ = -1;
  }

  // Iterate over the inputs and outputs of exprs and update
  //  the liveness info of local buffers if applicable.
  void collectLivenessInfo(const Expr* expr) {
    if (!ir_utils::isTvOp(expr)) {
      return;
    }

    auto out_tv = expr->outputs()[0]->as<TensorView>();

    // Collect all tv's that resolve a broadcast in this
    //  expr. The current analysis isn't enough to capture
    //  their liveness range.
    for (auto input_tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
      auto maybe_alloc_info = getMaybeAllocInfoFromTV(input_tv);
      if (maybe_alloc_info.has_value()) {
        if (!isSerialBroadcastResolution(input_tv, out_tv)) {
          maybe_alloc_info.value()->inner_live_interval->markRead(current_pos_);
        } else {
          // Disable inner alias info for this buffer, since line number based
          //  analysis is no longer precise enough for inplace sharing
          //  if a serial broadcast is realized.
          maybe_alloc_info.value()->can_use_inner_alias = false;
        }

        auto outer_loop_info =
            ascendLoopNestToSameLevelAs(maybe_alloc_info.value());

        if (outer_loop_info) {
          maybe_alloc_info.value()->outer_live_interval->markRead(
              outer_loop_info->end_pos);
        } else {
          // Allocate is inlined in the innermost loop,
          //  so outer live interval is the same as inner.
          maybe_alloc_info.value()->outer_live_interval->markRead(current_pos_);
        }
      }
    }
    for (auto output_tv : ir_utils::filterByType<TensorView>(expr->outputs())) {
      auto maybe_alloc_info = getMaybeAllocInfoFromTV(output_tv);
      if (maybe_alloc_info.has_value()) {
        maybe_alloc_info.value()->inner_live_interval->markWrite(current_pos_);
        auto outer_loop_info =
            ascendLoopNestToSameLevelAs(maybe_alloc_info.value());
        if (outer_loop_info) {
          maybe_alloc_info.value()->outer_live_interval->markWrite(
              outer_loop_info->start_pos);
        } else {
          maybe_alloc_info.value()->outer_live_interval->markWrite(
              current_pos_);
        }
      }
    }
  }
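
  // To make the two intervals concrete: in the annotated example above the
  // class, if the expr "T1 = ..." sits at position 5 and "T3 = T1 + ..." at
  // position 7, T1's inner live interval is [5, 7), while its outer live
  // interval spans the "For Idx1" loop, i.e. from that loop's start position
  // to its end position, because Idx1 is the loop at the same level as T1's
  // allocation that contains both exprs. (Positions are illustrative; actual
  // values depend on the expr list.)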

  //! Find the loop level of expr that appears in the same scope as
  //!  the reference allocate. E.g.:
  //!
  //!  For ...
  //!    For ...
  //!      Allocate <---- reference arg
  //!      For ...
  //!        For ...
  //!      For ... <---- this function returns `ScopeInfo` for this loop
  //!        For ...
  //!          expr  <---- current expr (implied in current_stack_ and
  //!                      current_pos_ )
  //!  Assumes that expr either writes to or reads from the reference allocate.
  ScopeInfo* ascendLoopNestToSameLevelAs(AllocationUseDefInfo* reference) {
    auto allocate_loop_info = reference->loop_info;
    if (allocate_loop_info->loop == nullptr) {
      if (current_stack_.size() > 1) {
        return current_stack_[1];
      }
      return nullptr;
    }

    for (const auto idx : c10::irange(current_stack_.size() - 1)) {
      if (current_stack_[idx] == allocate_loop_info) {
        return current_stack_[idx + 1];
      }
    }

    TORCH_INTERNAL_ASSERT(
        current_stack_.back() == allocate_loop_info,
        "lower_alias_memory : expr outer loop inconsistent with allocate");

    // Returning a nullptr means the allocate is in the current stack frame.
    return nullptr;
  }

  c10::optional<AllocationInfoPtr> getMaybeAllocInfoFromTV(TensorView* tv) {
    auto alloc_it = map_tv_to_allocations_.find(tv->name());
    if (alloc_it == map_tv_to_allocations_.end()) {
      return c10::nullopt;
    }
    return alloc_it->second;
  }

  //! Factory function for internal loop information data
  ScopeInfo* makeScopeInfo(kir::ForLoop* loop) {
    auto loop_info_ptr = std::make_unique<ScopeInfo>();
    auto loop_info = loop_info_ptr.get();
    loop_info->start_pos = current_pos_;
    loop_info->end_pos = -1;
    loop_info->loop = loop;
    all_loop_infos_.emplace_back(std::move(loop_info_ptr));

    if (loop == nullptr) {
      TORCH_INTERNAL_ASSERT(
          !global_scope_info_, "Should only create global scope info once!");
      global_scope_info_ = loop_info;
    } else {
      map_loop_pos_to_loop_info_[current_pos_] = loop_info;
    }
    return loop_info;
  }

  //! Factory function for internal use-def information data
  AllocationUseDefInfo* makeUseDefInfo() {
    auto alloc_info_ptr = std::make_unique<AllocationUseDefInfo>();
    auto alloc_info = alloc_info_ptr.get();

    alloc_info->alloc_pos = current_pos_;
    alloc_info->inner_alias_list_ =
        std::make_unique<std::vector<AllocationUseDefInfo*>>();
    alloc_info->inner_live_interval = std::make_unique<BufferLiveInterval>();
    alloc_info->inner_subscribed_intevals =
        std::make_unique<BufferLiveIntervalPtrList>();
    alloc_info->outer_live_interval = std::make_unique<BufferLiveInterval>();
    alloc_info->outer_subscribed_intevals =
        std::make_unique<BufferLiveIntervalPtrList>();
    all_allocations_.emplace_back(std::move(alloc_info_ptr));
    return alloc_info;
  }

  // Realize buffer alias and keep track of the alias info.
  void setAlias(AllocationInfoPtr from, AllocationInfoPtr to) {
    alias_map_[from] = to;
    from->alloc_expr->setAlias(to->alloc_expr);
    from->alias_to = to->alloc_expr;
  }

 private:
  friend BufferReuseDebugPrinter;
  friend class SerialBroadcastIntervalExpansion;

  //! Allocation sites that will participate in this analysis
  std::unordered_map<const kir::Allocate*, AllocationInfoPtr>
      map_allocate_to_info_;

  //! Map TensorView name to Allocate node.
  //!  Note: this assumes that each tensor view is only allocated once.
  std::unordered_map<StmtNameType, AllocationInfoPtr> map_tv_to_allocations_;

  //! Keeps track of all the allocations that have been set to alias
  std::unordered_map<AllocationInfoPtr, AllocationInfoPtr> alias_map_;

  //! Keep track of stack:
  std::vector<ScopeInfo*> current_stack_;

  //! Contains start and end position of the global scope
  ScopeInfo* global_scope_info_ = nullptr;

  //! map loop start position to loop info
  std::unordered_map<int, ScopeInfo*> map_loop_pos_to_loop_info_;

  //! Owning list of collected allocation info
  AllocationInfoOwningList all_allocations_;

  //! Owning list of collected scope/loop info
  ScopeInfoOwningPtrList all_loop_infos_;

  //! Expression evaluator to infer the size of register allocations
  kir::ExpressionEvaluator expr_evaluator_;

  //! Position counter when iterating through the exprs list
  int current_pos_ = -1;

  //! Debug info:
  BufferReuseDebugPrinter* debug_printer_ = nullptr;
};
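
// A rough usage sketch of the analysis above (mirroring what
// AllocateReuseModifier does below). Names like `some_allocate` are purely
// illustrative:
//
//   BufferUseDefInfo info(exprs);
//   info.prepareInnerSharingAnalysis();
//   auto maybe_info = info.getMaybeReuseInfoFor(some_allocate);
//   if (maybe_info.has_value()) {
//     auto interval = maybe_info.value()->inner_live_interval.get();
//     // interval->firstWrite() / lastRead() give the [write, read) range.
//   }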

void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) {
  TORCH_INTERNAL_ASSERT(buffer_info_ != nullptr);
  std::string message_header(" \033[1;32m^^^^^ ---Buffer Reuse Info--- ");
  std::string message_end(" \033[0m\n");
  if (!buffer_info_->map_allocate_to_info_.count(alloc)) {
    // This buffer is not considered for any sharing, either because of an
    //  unsupported op or a size below the threshold.
    return;
  }

  auto alloc_info = buffer_info_->map_allocate_to_info_.at(alloc);

  indent() << message_header;
  if (alloc_info->alias_to) {
    if (alloc_info->is_inner_alias) {
      os_ << "(inner) ";
    } else {
      os_ << "(outer) ";
    }
    os_ << " alias to alloc at pos "
        << buffer_info_->getMaybeReuseInfoFor(alloc_info->alias_to)
               .value()
               ->alloc_pos
        << " ";
  } else {
    os_ << " not aliased ";
  }

  os_ << " , ";

  if (alloc_info->can_use_inner_alias) {
    os_ << "inner live interval: ";
    os_ << alloc_info->inner_live_interval->toString() << " , ";
  }
  os_ << "size expr : " << alloc_info->size_expr << " , "
      << "outer live interval: " << alloc_info->outer_live_interval->toString();
  indent() << message_end;
}
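
// With the debug dump enabled, each allocate line in the printed kernel is
// followed by an annotation roughly of the form
//   ^^^^^ ---Buffer Reuse Info--- (inner)  alias to alloc at pos 3 ,
//   inner live interval: [ 5 , 7 ) , size expr : 4 , outer live interval: ...
// (positions and sizes here are illustrative; exact spacing and ANSI
//  coloring come from the strings above).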

//! Reuse Allocation nodes via pointer aliasing
class AllocateReuseModifier {
 public:
  static void modify(const std::vector<Expr*>& exprs) {
    AllocateReuseModifier modifier(exprs);
  }

  static void debugPrint(const std::vector<Expr*>& exprs) {
    BufferReuseDebugPrinter debug_printer;
    AllocateReuseModifier modifier(exprs, &debug_printer);
    std::cout << debug_printer.dumpDebugInfo();
  }

 private:
  AllocateReuseModifier(
      const std::vector<Expr*>& exprs,
      BufferReuseDebugPrinter* debug_printer_ = nullptr)
      : buffer_info_(exprs, debug_printer_) {
    // Perform in-place sharing first and then outer-liveness-based sharing:
    //  outer liveness info can still be used after some buffers have already
    //  been aliased through in-place re-use, but the reverse order would not
    //  work.
    buffer_info_.prepareInnerSharingAnalysis();
    handleScope(exprs);

    inner_aliasing_pass_ = false;

    buffer_info_.prepareOuterSharingAnalysis();
    handleScope(exprs);
  }

  // Second visit of an allocate op
  void handle(kir::Allocate* allocate) {
    // Check if this allocation site is one that we want to re-use or
    //  replace with an alias.

    auto maybe_alloc_info = buffer_info_.getMaybeReuseInfoFor(allocate);
    if (maybe_alloc_info.has_value() &&
        maybe_alloc_info.value()->alias_to == nullptr) {
      // Try to re-use existing allocates
      if (!tryReuseOtherAllocate(maybe_alloc_info.value())) {
        // If we didn't re-use, register this allocate so that future
        //  allocates can re-use this one.
        current_visible_buffer_stack_.back()->push_back(
            maybe_alloc_info.value());
      }
    }
  }

  bool tryReuseOtherAllocate(AllocationInfoPtr alloc_info) {
    if (!alloc_info->should_try_alias) {
      return false;
    }
    if (!alloc_info->inner_alias_list_->empty()) {
      // Avoid 2-hop aliasing for simplicity. This could be supported if
      //  really needed in extreme cases.
      return false;
    }

    // Move backwards on list of re-usable allocates on the stack, prefer
    //  reusing nearest allocation
    for (auto reuse_stack_it = current_visible_buffer_stack_.rbegin();
         reuse_stack_it != current_visible_buffer_stack_.rend();
         reuse_stack_it++) {
      for (auto alloc_to_reuse_it = (*reuse_stack_it)->rbegin();
           alloc_to_reuse_it != (*reuse_stack_it)->rend();
           alloc_to_reuse_it++) {
        auto alloc_to_reuse = *alloc_to_reuse_it;

        // Check if this re-use candidate is an alias
        if (alloc_to_reuse->alias_to != nullptr) {
          continue;
        }

        // Check if this alloc has the same mem type
        if (alloc_info->mem_type != alloc_to_reuse->mem_type) {
          continue;
        }

        // Check if this alloc has the same size
        if (alloc_info->size_expr != alloc_to_reuse->size_expr) {
          continue;
        }

        // Check if this alloc has the same data type
        if (alloc_info->data_type != alloc_to_reuse->data_type) {
          continue;
        }

        // Check if live intervals have any overlap
        auto subscribed_intervals = inner_aliasing_pass_
            ? alloc_to_reuse->inner_subscribed_intevals.get()
            : alloc_to_reuse->outer_subscribed_intevals.get();

        auto alloc_live_interval = inner_aliasing_pass_
            ? alloc_info->inner_live_interval.get()
            : alloc_info->outer_live_interval.get();

        if (std::any_of(
                subscribed_intervals->begin(),
                subscribed_intervals->end(),
                [alloc_live_interval](auto subscribed_interval) {
                  return alloc_live_interval->intersect(subscribed_interval);
                })) {
          continue;
        }

        // Special checks for inner sharing pass
        if (inner_aliasing_pass_ &&
            !isValidInnerSharing(alloc_to_reuse, alloc_info)) {
          continue;
        }

        if (alloc_info->alloc_expr->buffer()->isA<TensorView>()) {
          // Compare against the re-use candidate's buffer, not our own.
          if (!alloc_to_reuse->alloc_expr->buffer()->isA<TensorView>()) {
            continue;
          }
          auto this_tv = alloc_info->alloc_expr->buffer()->as<TensorView>();
          auto reuse_tv =
              alloc_to_reuse->alloc_expr->buffer()->as<TensorView>();
          // Check that either both tv's are vectorized accesses, or neither
          //  are. Vectorized allocations require correct alignment so they can
          //  only alias with other allocations with the right alignment.
          const auto& va = GpuLower::current()->vectorizedAccesses();
          if ((va.find(this_tv) == va.end()) !=
              (va.find(reuse_tv) == va.end())) {
            return false;
          }

          // Shared memory is all aligned to 128 bits, local memory might not be
          if (this_tv->getMemoryType() == MemoryType::Local &&
              va.find(this_tv) != va.end()) {
            // Make sure alignment matches
            if (va.at(this_tv) != va.at(reuse_tv)) {
              return false;
            }
          }
        }

        // TODO:
        //  Outer interval based sharing supports arbitrary re-indexing into
        //  the same buffer and would require additional syncs if fully
        //  enabled.
        //  Need a few more checks to insert syncs if necessary before turning
        //  on this sharing.
        if (!inner_aliasing_pass_ &&
            alloc_info->mem_type == MemoryType::Shared) {
          continue;
        }

        // Now re-use the alloc here and be sure to update
        reUseAllocation(alloc_info, alloc_to_reuse);
        return true;
      }
    }
    return false;
  }
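
  // Putting the checks above together on a small, hypothetical example: two
  // Local, float, size-4 buffers whose inner live intervals are [10, 14) and
  // [16, 20) pass every filter, so the later one would be marked as an alias
  // of the earlier one and reuse its registers; if the second buffer were
  // instead first written at position 12, the interval check would reject
  // the pair.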

  void handle(Expr* expr) {
    if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
      handle(ite);
    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      handle(for_loop);
    } else if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) {
      handle(allocate);
    }
  }

  void handle(const kir::ForLoop* for_loop) {
    handleScope(for_loop->body().exprs());
  }

  void handle(const kir::IfThenElse* ite) {
    TORCH_INTERNAL_ASSERT(
        false,
        "lower_alias_memory: IfThenElse before unrolling is not yet supported");
  }

  void handleScope(const std::vector<Expr*>& exprs) {
    current_visible_buffer_stack_.emplace_back(
        std::make_unique<AllocationInfoList>());
    for (auto expr : exprs) {
      handle(expr);
    }
    current_visible_buffer_stack_.pop_back();
  }

  struct InPlaceSharingInfo {
    bool has_broadcast_between = false;
    bool has_unsupported_op = false;
  };

  //! Careful, heavy check on inner sharing candidates; the currently
  //!  enforced conditions are:
  //!
  //! 1. The two buffers have a producer-consumer relationship
  //! 2. No halo in the allocated iter domains
  //! 3. Index equivalence is required when sharing across a broadcast
  bool isValidInnerSharing(
      AllocationUseDefInfo* alloc_info,
      AllocationUseDefInfo* to_reuse) {
    // Disable if either of the buffers does not support inner sharing
    if (!alloc_info->can_use_inner_alias || !to_reuse->can_use_inner_alias) {
      return false;
    }
    // Assume inputs are TV allocations, which should have been checked
    //  before reaching this point.
    auto this_tv = alloc_info->alloc_expr->buffer()->as<TensorView>();
    auto reuse_tv = to_reuse->alloc_expr->buffer()->as<TensorView>();

    // Aggressively disable inner sharing for swizzled tvs since
    //  the indexing order is in general not tractable.
    //  But outer sharing should still apply.
    if (this_tv->hasSwizzleOp() || reuse_tv->hasSwizzleOp()) {
      return false;
    }

    // Check the values in between the two buffers.
    auto vals_between_this_and_reuse =
        DependencyCheck::getAllValsBetween({this_tv}, {reuse_tv});
    if (vals_between_this_and_reuse.empty()) {
      vals_between_this_and_reuse =
          DependencyCheck::getAllValsBetween({reuse_tv}, {this_tv});
    }

    if (!vals_between_this_and_reuse.empty()) {
      // Sharing across difficult ops is temporarily disabled for inner
      //  sharing; this can be relaxed gradually.
      auto topo_info = checkOpsInBetween(vals_between_this_and_reuse);

      // Avoid difficult and future-introduced ops
      if (topo_info.has_unsupported_op) {
        return false;
      }

      // Get information on the allocated domains of the
      //  two buffers
      auto& local_alloc_map = GpuLower::current()->localAllocationInfoMap();
      auto alloc_it = local_alloc_map.find(alloc_info->alloc_expr);
      auto to_reuse_it = local_alloc_map.find(to_reuse->alloc_expr);
      if (alloc_it == local_alloc_map.end() ||
          to_reuse_it == local_alloc_map.end()) {
        return false;
      }

      // Disable in-place reusing for halo ops, since a halo access can visit
      //  the same point multiple times.
      if (alloc_it->second->has_halo || to_reuse_it->second->has_halo) {
        return false;
      }

      // Require matched iter domains for sharing across a broadcast
      if (topo_info.has_broadcast_between) {
        auto& alloc_domains = alloc_it->second->alloc_domains;
        auto& reuse_domains = to_reuse_it->second->alloc_domains;

        return allocationDomainsIndexMapped(alloc_domains, reuse_domains);
      }

      // If there are only pointwise and reduction ops in between and no
      //  broadcast, it should be ok to re-use in place.
      return true;
    }

    // this_tv and reuse_tv are not dependencies of each other,
    //  which means we cannot use inner sharing.
    return false;
  }
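
  // As a concrete (hypothetical) illustration: for a chain
  //   T2 = T1 + 1; T3 = T2 * 2;
  // with T1, T2, T3 all in registers, T3 may inner-share T1's buffer since
  // only pointwise ops sit between them, whereas if T2 were instead produced
  // by a BroadcastOp, the sharing would additionally require the allocated
  // iter domains of T1 and T3 to be exactly mapped.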

  InPlaceSharingInfo checkOpsInBetween(std::vector<Val*>& all_used_vals) {
    InPlaceSharingInfo info;
    std::unordered_set<Val*> all_used_val_set(
        all_used_vals.begin(), all_used_vals.end());

    for (auto val : all_used_vals) {
      if (auto tv = dynamic_cast<TensorView*>(val)) {
        auto tv_def = tv->definition();
        if (!tv_def) {
          continue;
        }
        if (!isPointwiseTvOp(tv_def) && !ir_utils::isReductionTvOp(tv_def)) {
          if (isBroadcastTvOp(tv_def)) {
            info.has_broadcast_between = true;
          } else {
            info.has_unsupported_op = true;
          }
        }
      }
    }
    return info;
  }

  bool allocationDomainsIndexMapped(
      std::vector<IterDomain*>& alloc_domains,
      std::vector<IterDomain*>& reuse_domains) {
    // Require that the allocated domains are exactly mapped.
    if (alloc_domains.size() != reuse_domains.size()) {
      return false;
    }

    // Check index map for the corresponding axes.
    for (const auto id_it : c10::irange(alloc_domains.size())) {
      if (!GpuLower::current()->caMap()->areMapped(
              alloc_domains[id_it],
              reuse_domains[id_it],
              IdMappingMode::EXACT)) {
        return false;
      }
    }
    return true;
  }

  void reUseAllocation(
      AllocationUseDefInfo* alloc_info,
      AllocationUseDefInfo* to_reuse) {
    // Update analysis result
    if (inner_aliasing_pass_) {
      buffer_info_.useInnerAlias(alloc_info, to_reuse);
    } else {
      buffer_info_.useOuterAlias(alloc_info, to_reuse);
    }
  }

  // Do we have a true pointwise op?
  // (i.e. a TV op, excluding direct assignments and reductions)
  bool isPointwiseTvOp(const Expr* expr) {
    if (ir_utils::isTvOp(expr)) {
      return expr->isA<UnaryOp>() || expr->isA<BinaryOp>() ||
          expr->isA<TernaryOp>();
    }
    return false;
  }

  // Utility to capture broadcast ops
  bool isBroadcastTvOp(const Expr* expr) {
    if (!ir_utils::isTvOp(expr)) {
      return false;
    }
    return expr->isA<BroadcastOp>();
  }

 private:
  // Analysis result from the first pass collecting the use-defs
  BufferUseDefInfo buffer_info_;

  // Internal data keeping track of currently visible allocations as
  //  the pass iterates through the expr list, grouped by the stack
  //  layer of alloc ops.
  std::vector<std::unique_ptr<AllocationInfoList>>
      current_visible_buffer_stack_;

  // Marks state of current pass
  bool inner_aliasing_pass_ = true;
};

} // namespace

std::vector<Expr*> reuseMemoryAllocations(const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("reuseMemoryAllocations");
  bool debug_print = isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo);
  if (debug_print) {
    AllocateReuseModifier::debugPrint(exprs);
  }
  AllocateReuseModifier::modify(exprs);
  return exprs;
}
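
// Usage note: this is the entry point of the pass, invoked during kernel
// lowering. It mutates the kir::Allocate nodes in place by setting their
// alias targets and returns the expression list unchanged; the verbose
// per-allocation annotations from AllocateReuseModifier::debugPrint are
// gated by DebugDumpOption::BufferReuseInfo.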

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch