1 | #include <lower_alias_memory.h> |
2 | |
3 | #include <instrumentation.h> |
4 | #include <ir_iostream.h> |
5 | #include <ir_utils.h> |
6 | #include <kernel_expr_evaluator.h> |
7 | #include <kernel_ir.h> |
8 | #include <lower2device.h> |
9 | #include <lower_utils.h> |
10 | |
11 | #include <sstream> |
12 | #include <unordered_map> |
13 | #include <unordered_set> |
14 | |
15 | namespace torch { |
16 | namespace jit { |
17 | namespace fuser { |
18 | namespace cuda { |
19 | |
20 | namespace { |
21 | // Alias used for std::transform |
22 | IterDomain* exactConcreteId(IterDomain* id) { |
23 | return GpuLower::current()->caMap()->getConcreteMappedID( |
24 | id, IdMappingMode::EXACT); |
25 | } |
26 | |
//! Checks that the current loop nest is realizing a serial
//!  broadcast so that each index of the producer buffer can be visited
//!  multiple times, in which case the aggressive in-place sharing is
//!  not valid.
30 | bool isSerialBroadcastResolution(TensorView* producer, TensorView* consumer) { |
  //! Note: see issue #1785:
  //! serial broadcast resolution doesn't only happen to
  //!  immediate outputs of broadcast ops. We can also have,
  //!  for example:
  //!  T1[I,B] = broadcast(T0[I])
  //!  T3[I,I] = T1[I,B] + T2[I,I]
  //!  T4[I,I] = T3[I,I]
  //! which generates the following loop nest:
39 | //! alloc T0[4] |
40 | //! For i in 0..3 |
41 | //! T0[...] = |
42 | //! |
43 | //! For j in 0...X: |
44 | //! alloc T3[4] |
45 | //! for k in 0..3: |
46 | //! alloc T1[1] |
47 | //! T1[0] = T0[k] // <- This is actually a broadcast resolution |
48 | //! T3[k] = T1[0] + T2[...] |
49 | //! T4[...] = T3[...] |
50 | //! |
  //! In this case we are actually visiting each pixel of T0 in each
  //!  iteration of the j loop, while T1 is the broadcasted tensor
  //!  causing this reuse.
  //!
  //! The current version of the check covers this scenario by checking
  //!  the root ids of the consumer's concrete loop ids. Any time a local
  //!  tensor like T0 appears in a re-use scenario like the above, we
  //!  should see a serial loop id (here the j loop) that was derived
  //!  from some root id that doesn't concretely map to T0's domain.
59 | |
60 | // Serial concrete loop id's that cover consumer's iter domain. |
61 | std::vector<Val*> consumer_serial_loop_concrete_ids; |
62 | |
63 | for (auto consumer_leaf_id : consumer->domain()->domain()) { |
64 | auto concrete_loop_id = GpuLower::current()->caMap()->getConcreteMappedID( |
65 | consumer_leaf_id, IdMappingMode::LOOP); |
66 | |
67 | // Check for any serial loop id with non-trivial extent |
68 | if (!concrete_loop_id->isThread() && |
69 | !concrete_loop_id->extent()->isOneInt()) { |
70 | consumer_serial_loop_concrete_ids.push_back(concrete_loop_id); |
71 | } |
72 | } |
73 | |
  // Collect the root ids that the serial loop iterdomains
  //  are transformed from.
76 | auto serial_loop_roots = InputsOf::outputs( |
77 | FusionGuard::getCurFusion(), consumer_serial_loop_concrete_ids); |
78 | |
79 | // Collect exact concrete id's in producer's root domain |
80 | std::unordered_set<IterDomain*> producer_exact_concrete_root_ids; |
81 | auto producer_root = |
82 | TensorDomain::noReductions(producer->getMaybeRFactorDomain()); |
83 | std::transform( |
84 | producer_root.begin(), |
85 | producer_root.end(), |
86 | std::inserter( |
87 | producer_exact_concrete_root_ids, |
88 | producer_exact_concrete_root_ids.begin()), |
89 | exactConcreteId); |
90 | |
  // Check if the serial loop roots index any exact root id that
  //  is not within the set of the producer's exact root ids. Such
  //  ids imply that the same producer pixel is accessed in
  //  multiple iterations of the materialized serial loop.
95 | for (auto serial_loop_root : |
96 | ir_utils::filterByType<IterDomain>(serial_loop_roots)) { |
97 | if (!producer_exact_concrete_root_ids.count( |
98 | GpuLower::current()->caMap()->getConcreteMappedID( |
99 | serial_loop_root, IdMappingMode::EXACT))) { |
100 | return true; |
101 | } |
102 | } |
103 | |
104 | return false; |
105 | } |
106 | |
107 | //! Get string representation of Allocate size for symbolic comparison |
108 | //! |
109 | //! TODO: Some expr simplifications could also be helpful |
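//!
//! As a purely illustrative sketch (the exact spelling of the op-type
//!  names is not guaranteed), an allocation of size 4 * blockDim.x
//!  could print as mul(4,@blockDim.x); two allocations are treated as
//!  equal in size only when these strings compare equal.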
110 | class SymbolicSizePrinter : private OptOutConstDispatch { |
111 | public: |
112 | static std::string printSize(const kir::Allocate* allocate) { |
113 | SymbolicSizePrinter printer; |
114 | printer.handle(allocate->size()); |
115 | return printer.os_.str(); |
116 | } |
117 | |
118 | private: |
119 | using OptOutConstDispatch::handle; |
120 | |
121 | void handle(const Int* node) final { |
122 | if (auto def = node->definition()) { |
123 | OptOutConstDispatch::handle(def); |
124 | } else if (node->isConst()) { |
125 | os_ << *node->value(); |
126 | } else { |
127 | os_ << "ki" << node->name(); |
128 | } |
129 | } |
130 | |
131 | void handle(const NamedScalar* named_scalar) final { |
132 | os_ << "@" << named_scalar->name(); |
133 | } |
134 | |
  void handle(const UnaryOp* unary_op) final {
    os_ << unary_op->getUnaryOpType() << "(";
    OptOutConstDispatch::handle(unary_op->in());
    os_ << ")";
  }
140 | |
141 | void handle(const BinaryOp* binary_op) final { |
142 | os_ << binary_op->getBinaryOpType() << "(" ; |
143 | OptOutConstDispatch::handle(binary_op->lhs()); |
144 | os_ << "," ; |
145 | OptOutConstDispatch::handle(binary_op->rhs()); |
146 | os_ << ")" ; |
147 | } |
148 | |
149 | private: |
150 | std::stringstream os_; |
151 | }; |
152 | |
153 | class BufferUseDefInfo; |
154 | //! A debug printer internal to this pass to support |
155 | //! future expansion and inline annotation of pass info. |
156 | class BufferReuseDebugPrinter { |
157 | enum class DebugLineType { EXPR, START_BLOCK, END_BLOCK }; |
158 | |
159 | struct ExprInfo { |
160 | int lineno = 0; |
161 | DebugLineType line_type = DebugLineType::EXPR; |
162 | }; |
163 | |
164 | using DebugEntry = std::pair<ExprInfo, Expr*>; |
165 | using DebugEntryPtr = std::unique_ptr<DebugEntry>; |
166 | |
167 | public: |
168 | BufferReuseDebugPrinter() : ir_printer_(os_){}; |
169 | |
170 | std::string dumpDebugInfo() { |
171 | os_.clear(); |
172 | for (auto& debug_entry : debug_info_) { |
173 | switch (debug_entry->first.line_type) { |
174 | case DebugLineType::START_BLOCK: |
175 | startBlock(); |
176 | break; |
177 | case DebugLineType::END_BLOCK: |
178 | endBlock(); |
179 | break; |
180 | case DebugLineType::EXPR: |
181 | os_ << debug_entry->first.lineno; |
182 | handle(debug_entry->second); |
183 | break; |
184 | default: |
185 | TORCH_INTERNAL_ASSERT(false, "unreachable" ); |
186 | } |
187 | } |
188 | os_ << "\n\n" ; |
189 | return os_.str(); |
190 | } |
191 | |
192 | private: |
193 | friend class BufferUseDefInfo; |
194 | |
195 | void pushBack(int lineno, Expr* expr) { |
196 | makeExprEntry(lineno, expr); |
197 | } |
198 | |
199 | void pushScope() { |
200 | makeScopeEntry(DebugLineType::START_BLOCK); |
201 | } |
202 | |
203 | void popScope() { |
204 | makeScopeEntry(DebugLineType::END_BLOCK); |
205 | } |
206 | |
207 | void makeExprEntry(int lineno, Expr* expr) { |
208 | auto debug_entry_ptr = std::make_unique<DebugEntry>(); |
209 | debug_entry_ptr->first.lineno = lineno; |
210 | debug_entry_ptr->second = expr; |
211 | debug_info_.emplace_back(std::move(debug_entry_ptr)); |
212 | } |
213 | |
214 | void makeScopeEntry(DebugLineType line_type) { |
215 | TORCH_INTERNAL_ASSERT( |
216 | line_type == DebugLineType::END_BLOCK || |
217 | line_type == DebugLineType::START_BLOCK); |
218 | auto debug_entry_ptr = std::make_unique<DebugEntry>(); |
219 | debug_entry_ptr->first.line_type = line_type; |
220 | debug_entry_ptr->second = nullptr; |
221 | debug_info_.emplace_back(std::move(debug_entry_ptr)); |
222 | } |
223 | |
224 | void handle(const Expr* node) { |
225 | if (auto for_loop = dynamic_cast<const kir::ForLoop*>(node)) { |
226 | handle(for_loop); |
227 | } else if (auto ite = dynamic_cast<const kir::IfThenElse*>(node)) { |
228 | handle(ite); |
229 | } else { |
230 | indent(); |
231 | ir_printer_.handle(node); |
232 | } |
233 | if (auto alloc = dynamic_cast<const kir::Allocate*>(node)) { |
234 | printAllocInfo(alloc); |
235 | } |
236 | } |
237 | |
238 | void handle(const kir::ForLoop* node) { |
239 | indent(); |
240 | os_ << "FOR " ; |
241 | ir_printer_.handle(node->index()); |
242 | os_ << " in " ; |
243 | ir_printer_.handle(node->iter_domain()); |
244 | os_ << ":\n" ; |
245 | } |
246 | |
247 | void handle(const kir::IfThenElse* node) { |
    // This pass doesn't yet need to handle IfThenElse,
    //  but this could be filled in here if the printer is
    //  reused by other passes or we have more complex ite
    //  patterns to print.
253 | TORCH_INTERNAL_ASSERT(false, "unsupported" ); |
254 | } |
255 | |
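  //! Appends the recorded sharing decision and live intervals of an
  //!  allocation after its line in the dump. A purely illustrative
  //!  line (hypothetical positions and size) could look like:
  //!   ^^^^^ ---Buffer Reuse Info--- (inner)  alias to alloc at pos 3 ,
  //!   inner live interval: [ 5 , 8 ) , size expr : 4 ,
  //!   outer live interval: [ 4 , 10 )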
256 | void printAllocInfo(const kir::Allocate* alloc); |
257 | |
258 | std::stringstream& indent() { |
259 | for (const auto i : c10::irange(indent_level_)) { |
260 | (void)i; // Suppress unused variable warning |
261 | os_ << " " ; |
262 | } |
263 | return os_; |
264 | } |
265 | |
266 | void startBlock() { |
267 | indent_level_++; |
268 | } |
269 | |
270 | void endBlock() { |
271 | indent_level_--; |
272 | } |
273 | |
274 | private: |
275 | std::stringstream os_; |
276 | IrPrinter ir_printer_; |
277 | int indent_level_ = 0; |
278 | |
279 | std::vector<DebugEntryPtr> debug_info_; |
280 | BufferUseDefInfo* buffer_info_ = nullptr; |
281 | }; |
282 | |
//! Utility class for modeling the liveness interval.
//! The first write and last read
//!  are based on positions in the linear order within
//!  the Kernel IR.
//! The interval is semi-open,
//!  i.e. [First_Write, Last_Read),
//! so the buffer is NOT available at exactly the First_Write
//!  position while it IS available at Last_Read.
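//!
//! A minimal usage sketch with hypothetical positions:
//!
//!   BufferLiveInterval interval;
//!   interval.markWrite(3); // first write at position 3
//!   interval.markRead(5);  // a read at position 5
//!   interval.markRead(7);  // last read at position 7
//!   // interval.toString() == "[ 3 , 7 )"
//!
//! Another buffer whose first write happens at position 7 or later
//!  does not intersect this interval and could share its storage.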
291 | class BufferLiveInterval { |
292 | public: |
293 | // Simple detection of intersection of two intervals |
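  //  For illustration with hypothetical positions: [2, 5) and [4, 8)
  //  intersect, while [2, 5) and [5, 8) do not, because an interval is
  //  open at its last-read position.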
294 | bool intersect(BufferLiveInterval* other) { |
295 | if (first_write_pos_ <= other->first_write_pos_) { |
296 | return other->first_write_pos_ < last_read_pos_; |
297 | } else { |
298 | return first_write_pos_ < other->last_read_pos_; |
299 | } |
300 | } |
301 | |
302 | void markWrite(int pos) { |
303 | if (first_write_pos_ == -1) { |
304 | first_write_pos_ = pos; |
305 | } |
306 | } |
307 | |
308 | void markRead(int pos) { |
309 | last_read_pos_ = pos; |
    TORCH_INTERNAL_ASSERT(
        first_write_pos_ > 0,
        "lower_alias_memory: a read seen before any write");
    TORCH_INTERNAL_ASSERT(
        pos > first_write_pos_,
        "lower_alias_memory: marking a read before write");
316 | all_read_pos_.push_back(pos); |
317 | } |
318 | |
319 | const auto& allReads() { |
320 | return all_read_pos_; |
321 | } |
322 | |
323 | auto firstWrite() const { |
324 | return first_write_pos_; |
325 | } |
326 | |
327 | auto lastRead() const { |
328 | return last_read_pos_; |
329 | } |
330 | |
331 | std::string toString() { |
332 | std::stringstream ss; |
333 | ss << "[ " << first_write_pos_ << " , " << last_read_pos_ << " )" ; |
334 | return ss.str(); |
335 | } |
336 | |
337 | private: |
338 | int first_write_pos_ = -1; |
339 | int last_read_pos_ = -1; |
340 | std::vector<int> all_read_pos_; |
341 | }; |
342 | |
343 | using BufferLiveIntervalPtrList = std::vector<BufferLiveInterval*>; |
344 | |
345 | //! Thin struct to keep track of loops. The actual loop body is |
346 | //! considered live in [start_pos, end_pos) |
347 | struct ScopeInfo { |
348 | int start_pos = -1; |
349 | int end_pos = -1; |
350 | |
351 | // nullptr means it's global scope |
352 | kir::ForLoop* loop = nullptr; |
353 | }; |
354 | |
355 | using ScopeInfoOwningPtr = std::unique_ptr<ScopeInfo>; |
356 | using ScopeInfoOwningPtrList = std::vector<ScopeInfoOwningPtr>; |
357 | |
358 | //! Utility class to record the read and write of each |
359 | //! allocated buffer. |
360 | //! |
//! Note:
//!  this simplified interval analysis only works on pointwise ops,
//!  reductions, and broadcasts, with no non-trivial IfThenElse and no
//!  non-trivial re-computation.
//!
//! Will probably at some point need dataflow and index analysis to
//!  precisely handle loop-carried dependencies.
368 | struct AllocationUseDefInfo { |
369 | kir::Allocate* alloc_expr = nullptr; |
370 | kir::Allocate* alias_to = nullptr; |
371 | bool is_inner_alias = false; |
372 | bool should_try_alias = true; |
373 | MemoryType mem_type = MemoryType::Local; |
374 | DataType data_type = DataType::Float; |
375 | std::string size_expr; |
376 | ScopeInfo* loop_info = nullptr; |
377 | bool can_use_inner_alias = true; |
378 | int alloc_pos = -1; |
379 | std::unique_ptr<std::vector<AllocationUseDefInfo*>> inner_alias_list_ = |
380 | nullptr; |
381 | std::unique_ptr<BufferLiveInterval> inner_live_interval = nullptr; |
382 | std::unique_ptr<BufferLiveIntervalPtrList> inner_subscribed_intevals = |
383 | nullptr; |
384 | std::unique_ptr<BufferLiveInterval> outer_live_interval = nullptr; |
385 | std::unique_ptr<BufferLiveIntervalPtrList> outer_subscribed_intevals = |
386 | nullptr; |
387 | }; |
388 | |
389 | using AllocationInfoOwningPtr = std::unique_ptr<AllocationUseDefInfo>; |
390 | using AllocationInfoOwningList = std::vector<AllocationInfoOwningPtr>; |
391 | using AllocationInfoPtr = AllocationUseDefInfo*; |
392 | using AllocationInfoList = std::vector<AllocationInfoPtr>; |
393 | |
394 | //! Analysis pass to collect the liveness info of local and shared buffers: |
395 | //! The liveness info is illustrated as follows: |
396 | //! |
397 | //! For Idx0 ... |
398 | //! Alloc(T1, register) |
399 | //! Alloc(T2, register) |
400 | //! Alloc(T3, register) |
401 | //! |
402 | //! For Idx1 ... <---------- Outer Live Interval of T1 begin |
403 | //! For Idx2 ... |
404 | //! T1 = ... <-- Inner Live Interval of T1 begin |
405 | //! T2 = ... |
406 | //! T3 = T1 + ... <-- Inner Live Interval of T1 end |
407 | //! T5 = T3 + ... |
408 | //! EndFor Idx2 |
409 | //! EndFor Idx1 <------- Outer Live Interval of T1 end |
410 | //! |
411 | //! Alloc(T4, register) |
412 | //! For Idx3 ... |
413 | //! T4 = ... |
414 | //! EndFor Idx3 |
415 | //! EndFor Idx0 |
416 | //! |
//! Each buffer is associated with an `inner_live_interval` and an
//!  `outer_live_interval`.
//! The inner interval marks the exprs that are the first write and
//!  last read of the buffer.
//! The outer interval marks the beginning of the loop of the first
//!  write and the end of the loop of the last read, both at the same
//!  loop level as the buffer allocation.
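//!
//! As an illustration on the sketch above (positions are hypothetical),
//!  T1's outer live interval covers the whole `For Idx1` loop while
//!  T4's outer live interval covers only the `For Idx3` loop. Since the
//!  two loops do not overlap, the outer-interval analysis could let T4
//!  re-use T1's storage, provided the other conditions (size, data
//!  type, vectorization alignment) also match.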
424 | class BufferUseDefInfo { |
425 | public: |
  // Only try to alias local memory buffers whose size exceeds this
  //  threshold.
427 | static constexpr long kRegisterSizeThreshold = 1; |
428 | |
429 | BufferUseDefInfo( |
430 | const std::vector<Expr*>& exprs, |
431 | BufferReuseDebugPrinter* debug_printer = nullptr) |
432 | : debug_printer_(debug_printer) { |
433 | if (debug_printer) { |
434 | debug_printer->buffer_info_ = this; |
435 | } |
436 | collectScopeInfo(exprs); |
437 | collectScopeUseDefInfo(exprs); |
438 | } |
439 | |
  //! Returns the collected allocation info of the buffer, including
  //!  its live intervals, if previously computed.
442 | c10::optional<AllocationInfoPtr> getMaybeReuseInfoFor( |
443 | kir::Allocate* allocate) const { |
444 | auto alloc_it = map_allocate_to_info_.find(allocate); |
445 | if (alloc_it == map_allocate_to_info_.end()) { |
446 | return c10::nullopt; |
447 | } |
448 | auto alloc = alloc_it->second; |
449 | return alloc; |
450 | } |
451 | |
452 | //! Realize alias of two buffers through inner alias analysis and |
453 | //! keep track of the re-use. |
454 | void useInnerAlias(AllocationInfoPtr from, AllocationInfoPtr to) { |
455 | to->inner_alias_list_->push_back(from); |
456 | to->inner_subscribed_intevals->push_back(from->inner_live_interval.get()); |
457 | setAlias(from, to); |
458 | from->is_inner_alias = true; |
459 | } |
460 | |
461 | //! Realize alias of two buffers through outer alias analysis and |
462 | //! keep track of the re-use. |
463 | void useOuterAlias(AllocationInfoPtr from, AllocationInfoPtr to) { |
464 | to->outer_subscribed_intevals->push_back(from->outer_live_interval.get()); |
465 | setAlias(from, to); |
466 | } |
467 | |
468 | //! To run before performing in-place sharing analysis. |
469 | //! Initializes the inner live intervals with each |
470 | //! allocation's inner live interval. |
471 | void prepareInnerSharingAnalysis() { |
472 | for (auto it : map_allocate_to_info_) { |
473 | auto alloc_info = it.second; |
474 | // At beginning only use interval for each |
475 | // allocate is their corresponding live interval |
476 | alloc_info->inner_subscribed_intevals->push_back( |
477 | alloc_info->inner_live_interval.get()); |
478 | } |
479 | } |
480 | |
  //! To run before performing outer interval based sharing analysis.
  //! Initializes the outer live intervals with the outer live interval
  //!  of each allocation and copies inner sharing information.
484 | void prepareOuterSharingAnalysis() { |
485 | for (auto it : map_allocate_to_info_) { |
486 | auto alloc_info = it.second; |
487 | if (!alias_map_.count(alloc_info)) { |
488 | alloc_info->outer_subscribed_intevals->push_back( |
489 | alloc_info->outer_live_interval.get()); |
490 | // Update only if this buffer isn't an alias |
491 | for (auto inner_alias : *(alloc_info->inner_alias_list_)) { |
492 | alloc_info->outer_subscribed_intevals->push_back( |
493 | inner_alias->outer_live_interval.get()); |
494 | } |
495 | } |
496 | } |
497 | } |
498 | |
499 | private: |
500 | void handle(Expr* expr) { |
501 | current_pos_++; |
502 | if (debug_printer_) { |
503 | debug_printer_->pushBack(current_pos_, expr); |
504 | } |
505 | if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) { |
506 | handle(alloc); |
507 | } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) { |
508 | handle(for_loop); |
509 | } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) { |
510 | handle(ite); |
511 | } else { |
512 | collectLivenessInfo(expr); |
513 | } |
514 | } |
515 | |
516 | void handleScope(const std::vector<Expr*>& exprs) { |
517 | if (debug_printer_) { |
518 | debug_printer_->pushScope(); |
519 | } |
520 | for (auto expr : exprs) { |
521 | handle(expr); |
522 | } |
523 | if (debug_printer_) { |
524 | debug_printer_->popScope(); |
525 | } |
526 | } |
527 | |
528 | void handle(kir::ForLoop* for_loop) { |
529 | auto loop_info = map_loop_pos_to_loop_info_.at(current_pos_); |
530 | current_stack_.push_back(loop_info); |
531 | handleScope(for_loop->body().exprs()); |
532 | current_stack_.pop_back(); |
533 | } |
534 | |
535 | void handle(kir::IfThenElse* ite) { |
536 | TORCH_INTERNAL_ASSERT( |
537 | false, "lower_alias_memory: no support for IfThenElse at this phase." ); |
538 | } |
539 | |
540 | // Generate allocation info for allocation after some pre-filtering |
541 | // conditions. |
542 | void handle(kir::Allocate* alloc) { |
543 | if (alloc->alias()) { |
544 | // We shouldn't really see a case like this in general, but |
545 | // some Fusion outputs could have been aliased to inputs. |
546 | // It should be safe to ignore these in the use-def analysis. |
547 | return; |
548 | } |
549 | |
550 | auto tv = dynamic_cast<TensorView*>(alloc->buffer()); |
551 | if (!tv) { |
552 | return; |
553 | } |
554 | |
555 | // Collect the allocate info data |
556 | |
557 | // Collect memory type, skip global buffers |
558 | auto mem_type = tv->getMemoryType(); |
559 | if (mem_type != MemoryType::Local && mem_type != MemoryType::Shared) { |
560 | return; |
561 | } |
562 | |
563 | // Skip smaller register sizes |
564 | bool should_try_alias = true; |
565 | if (mem_type == MemoryType::Local) { |
566 | const auto register_size = expr_evaluator_.evaluate(alloc->size()); |
567 | if (!register_size.has_value()) { |
568 | TORCH_WARN_ONCE( |
569 | "Lower_alias_memory : dynamic sized register allocation" ); |
570 | return; |
571 | } |
572 | if (register_size->as<int64_t>() <= kRegisterSizeThreshold) { |
573 | should_try_alias = false; |
574 | } |
575 | } |
576 | |
577 | auto data_type = tv->dtype(); |
578 | auto size_print = SymbolicSizePrinter::printSize(alloc); |
579 | |
580 | // Make sure we don't have conflicting information on record |
581 | TORCH_INTERNAL_ASSERT(!map_allocate_to_info_.count(alloc)); |
582 | TORCH_INTERNAL_ASSERT(!map_tv_to_allocations_.count(tv->name())); |
583 | |
584 | // make AllocationUseDefInfo: |
585 | auto alloc_info = makeUseDefInfo(); |
586 | alloc_info->alloc_expr = alloc; |
587 | alloc_info->mem_type = mem_type; |
588 | alloc_info->data_type = data_type; |
589 | alloc_info->size_expr = size_print; |
590 | alloc_info->loop_info = current_stack_.back(); |
591 | alloc_info->should_try_alias = should_try_alias; |
592 | |
593 | // record short cuts |
594 | map_allocate_to_info_[alloc] = alloc_info; |
595 | map_tv_to_allocations_[tv->name()] = alloc_info; |
596 | } |
597 | |
598 | void collectScopeUseDefInfo(const std::vector<Expr*>& exprs) { |
599 | // Reset position pointer |
600 | resetExprCounter(); |
601 | TORCH_INTERNAL_ASSERT(global_scope_info_ != nullptr); |
602 | current_stack_.push_back(global_scope_info_); |
603 | handleScope(exprs); |
604 | } |
605 | |
606 | void collectScopeInfo(const std::vector<Expr*>& exprs) { |
607 | // Reset position pointer |
608 | resetExprCounter(); |
609 | collectScopeInfoWithinLoop(exprs, nullptr); |
610 | } |
611 | |
612 | void collectScopeInfoWithinLoop( |
613 | const std::vector<Expr*>& exprs, |
614 | kir::ForLoop* current_loop) { |
615 | auto loop_info = makeScopeInfo(current_loop); |
616 | for (auto expr : exprs) { |
617 | current_pos_++; |
618 | if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) { |
619 | collectScopeInfoWithinLoop(for_loop->body().exprs(), for_loop); |
620 | } |
621 | } |
622 | loop_info->end_pos = current_pos_ + 1; |
623 | } |
624 | |
625 | void resetExprCounter() { |
626 | current_pos_ = -1; |
627 | } |
628 | |
  // Iterate over the inputs and outputs of exprs and update
  //  the liveness info of local buffers if applicable.
631 | void collectLivenessInfo(const Expr* expr) { |
632 | if (!ir_utils::isTvOp(expr)) { |
633 | return; |
634 | } |
635 | |
636 | auto out_tv = expr->outputs()[0]->as<TensorView>(); |
637 | |
    // Collect all tv's that resolve a broadcast in this
    //  expr. The current analysis isn't enough to capture
    //  their liveness range.
641 | for (auto input_tv : ir_utils::filterByType<TensorView>(expr->inputs())) { |
642 | auto maybe_alloc_info = getMaybeAllocInfoFromTV(input_tv); |
643 | if (maybe_alloc_info.has_value()) { |
644 | if (!isSerialBroadcastResolution(input_tv, out_tv)) { |
645 | maybe_alloc_info.value()->inner_live_interval->markRead(current_pos_); |
646 | } else { |
647 | // Disable inner alias info for this buffer, since line number based |
648 | // analysis is no longer precise enough for inplace sharing |
649 | // if a serial broadcast is realized. |
650 | maybe_alloc_info.value()->can_use_inner_alias = false; |
651 | } |
652 | |
653 | auto outer_loop_info = |
654 | ascendLoopNestToSameLevelAs(maybe_alloc_info.value()); |
655 | |
656 | if (outer_loop_info) { |
657 | maybe_alloc_info.value()->outer_live_interval->markRead( |
658 | outer_loop_info->end_pos); |
659 | } else { |
660 | // Allocate is inlined in the innermost loop, |
661 | // so outer live interval is the same as inner. |
662 | maybe_alloc_info.value()->outer_live_interval->markRead(current_pos_); |
663 | } |
664 | } |
665 | } |
666 | for (auto output_tv : ir_utils::filterByType<TensorView>(expr->outputs())) { |
667 | auto maybe_alloc_info = getMaybeAllocInfoFromTV(output_tv); |
668 | if (maybe_alloc_info.has_value()) { |
669 | maybe_alloc_info.value()->inner_live_interval->markWrite(current_pos_); |
670 | auto outer_loop_info = |
671 | ascendLoopNestToSameLevelAs(maybe_alloc_info.value()); |
672 | if (outer_loop_info) { |
673 | maybe_alloc_info.value()->outer_live_interval->markWrite( |
674 | outer_loop_info->start_pos); |
675 | } else { |
676 | maybe_alloc_info.value()->outer_live_interval->markWrite( |
677 | current_pos_); |
678 | } |
679 | } |
680 | } |
681 | } |
682 | |
  //! Find the loop level of expr that appears in the same scope as
  //!  the reference allocate. E.g.
685 | //! |
686 | //! For ... |
687 | //! For ... |
688 | //! Allocate <---- reference arg |
689 | //! For .. |
690 | //! For ... |
691 | //! For ... <---- this function returns `ScopeInfo` for this loop |
692 | //! For ... |
693 | //! expr <---- current expr (implied in current_stack_ and |
694 | //! current_pos_ ) |
695 | //! Assumes that expr either writes to or reads from the reference allocate. |
696 | ScopeInfo* ascendLoopNestToSameLevelAs(AllocationUseDefInfo* reference) { |
697 | auto allocate_loop_info = reference->loop_info; |
698 | if (allocate_loop_info->loop == nullptr) { |
699 | if (current_stack_.size() > 1) { |
700 | return current_stack_[1]; |
701 | } |
702 | return nullptr; |
703 | } |
704 | |
705 | for (const auto idx : c10::irange(current_stack_.size() - 1)) { |
706 | if (current_stack_[idx] == allocate_loop_info) { |
707 | return current_stack_[idx + 1]; |
708 | } |
709 | } |
710 | |
711 | TORCH_INTERNAL_ASSERT( |
712 | current_stack_.back() == allocate_loop_info, |
713 | "lower_alias_memory : expr outer loop inconsistent with allocate" ); |
714 | |
715 | // Returning a nullptr means the allocate is in the current stack frame. |
716 | return nullptr; |
717 | } |
718 | |
719 | c10::optional<AllocationInfoPtr> getMaybeAllocInfoFromTV(TensorView* tv) { |
720 | auto alloc_it = map_tv_to_allocations_.find(tv->name()); |
721 | if (alloc_it == map_tv_to_allocations_.end()) { |
722 | return c10::nullopt; |
723 | } |
724 | return alloc_it->second; |
725 | } |
726 | |
727 | //! Factory function for internal loop information data |
728 | ScopeInfo* makeScopeInfo(kir::ForLoop* loop) { |
729 | auto loop_info_ptr = std::make_unique<ScopeInfo>(); |
730 | auto loop_info = loop_info_ptr.get(); |
731 | loop_info->start_pos = current_pos_; |
732 | loop_info->end_pos = -1; |
733 | loop_info->loop = loop; |
734 | all_loop_infos_.emplace_back(std::move(loop_info_ptr)); |
735 | |
736 | if (loop == nullptr) { |
737 | TORCH_INTERNAL_ASSERT( |
738 | !global_scope_info_, "Should only create global scope info once!" ); |
739 | global_scope_info_ = loop_info; |
740 | } else { |
741 | map_loop_pos_to_loop_info_[current_pos_] = loop_info; |
742 | } |
743 | return loop_info; |
744 | } |
745 | |
746 | //! Factory function for internal use-def information data |
747 | AllocationUseDefInfo* makeUseDefInfo() { |
748 | auto alloc_info_ptr = std::make_unique<AllocationUseDefInfo>(); |
749 | auto alloc_info = alloc_info_ptr.get(); |
750 | |
751 | alloc_info->alloc_pos = current_pos_; |
752 | alloc_info->inner_alias_list_ = |
753 | std::make_unique<std::vector<AllocationUseDefInfo*>>(); |
754 | alloc_info->inner_live_interval = std::make_unique<BufferLiveInterval>(); |
755 | alloc_info->inner_subscribed_intevals = |
756 | std::make_unique<BufferLiveIntervalPtrList>(); |
757 | alloc_info->outer_live_interval = std::make_unique<BufferLiveInterval>(); |
758 | alloc_info->outer_subscribed_intevals = |
759 | std::make_unique<BufferLiveIntervalPtrList>(); |
760 | all_allocations_.emplace_back(std::move(alloc_info_ptr)); |
761 | return alloc_info; |
762 | } |
763 | |
764 | // Realize buffer alias and keep track of the alias info. |
765 | void setAlias(AllocationInfoPtr from, AllocationInfoPtr to) { |
766 | alias_map_[from] = to; |
767 | from->alloc_expr->setAlias(to->alloc_expr); |
768 | from->alias_to = to->alloc_expr; |
769 | } |
770 | |
771 | private: |
772 | friend BufferReuseDebugPrinter; |
773 | friend class SerialBroadcastIntervalExpansion; |
774 | |
775 | //! Allocation sites that will participate in this analysis |
776 | std::unordered_map<const kir::Allocate*, AllocationInfoPtr> |
777 | map_allocate_to_info_; |
778 | |
  //! Map TensorView name to its allocation info.
  //! Note: this assumes that each tensor view is only allocated once.
781 | std::unordered_map<StmtNameType, AllocationInfoPtr> map_tv_to_allocations_; |
782 | |
783 | //! Keeps track of all the allocations that have been set to alias |
784 | std::unordered_map<AllocationInfoPtr, AllocationInfoPtr> alias_map_; |
785 | |
786 | //! Keep track of stack: |
787 | std::vector<ScopeInfo*> current_stack_; |
788 | |
789 | //! Contains start and end position of the global scope |
790 | ScopeInfo* global_scope_info_ = nullptr; |
791 | |
792 | //! map loop start position to loop info |
793 | std::unordered_map<int, ScopeInfo*> map_loop_pos_to_loop_info_; |
794 | |
795 | //! Owning list of collected allocation info |
796 | AllocationInfoOwningList all_allocations_; |
797 | |
  //! Owning list of collected scope/loop info
  ScopeInfoOwningPtrList all_loop_infos_;
800 | |
801 | //! Expression Evaluator to infer size of register allocation |
802 | kir::ExpressionEvaluator expr_evaluator_; |
803 | |
804 | //! Position counter when iterating through the exprs list |
805 | int current_pos_ = -1; |
806 | |
807 | //! Debug info: |
808 | BufferReuseDebugPrinter* debug_printer_ = nullptr; |
809 | }; |
810 | |
811 | void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) { |
812 | TORCH_INTERNAL_ASSERT(buffer_info_ != nullptr); |
  std::string message_header(" \033[1;32m^^^^^ ---Buffer Reuse Info--- ");
814 | std::string message_end(" \033[0m\n" ); |
815 | if (!buffer_info_->map_allocate_to_info_.count(alloc)) { |
816 | // This buffer is not considered for any sharing, either |
817 | // because of un-supported op or size below threshold. |
818 | return; |
819 | } |
820 | |
821 | auto alloc_info = buffer_info_->map_allocate_to_info_.at(alloc); |
822 | |
823 | indent() << message_header; |
824 | if (alloc_info->alias_to) { |
825 | if (alloc_info->is_inner_alias) { |
826 | os_ << "(inner) " ; |
827 | } else { |
828 | os_ << "(outer) " ; |
829 | } |
830 | os_ << " alias to alloc at pos " |
831 | << buffer_info_->getMaybeReuseInfoFor(alloc_info->alias_to) |
832 | .value() |
833 | ->alloc_pos |
834 | << " " ; |
835 | } else { |
836 | os_ << " not aliased " ; |
837 | } |
838 | |
839 | os_ << " , " ; |
840 | |
841 | if (alloc_info->can_use_inner_alias) { |
842 | os_ << "inner live interval: " ; |
843 | os_ << alloc_info->inner_live_interval->toString() << " , " ; |
844 | } |
845 | os_ << "size expr : " << alloc_info->size_expr << " , " |
846 | << "outer live interval: " << alloc_info->outer_live_interval->toString(); |
847 | indent() << message_end; |
848 | } |
849 | |
850 | //! Reuse Allocation nodes via pointer aliasing |
851 | class AllocateReuseModifier { |
852 | public: |
853 | static void modify(const std::vector<Expr*>& exprs) { |
854 | AllocateReuseModifier modifier(exprs); |
855 | } |
856 | |
857 | static void debugPrint(const std::vector<Expr*>& exprs) { |
858 | BufferReuseDebugPrinter debug_printer; |
859 | AllocateReuseModifier modifier(exprs, &debug_printer); |
860 | std::cout << debug_printer.dumpDebugInfo(); |
861 | } |
862 | |
863 | private: |
864 | AllocateReuseModifier( |
865 | const std::vector<Expr*>& exprs, |
866 | BufferReuseDebugPrinter* debug_printer_ = nullptr) |
867 | : buffer_info_(exprs, debug_printer_) { |
    // Perform in-place sharing first and then outer liveness
    //  based sharing, since outer liveness info can still
    //  be used when some buffers already alias through
    //  in-place re-use, which wouldn't be the case if we did
    //  outer liveness based sharing first.
873 | buffer_info_.prepareInnerSharingAnalysis(); |
874 | handleScope(exprs); |
875 | |
876 | inner_aliasing_pass_ = false; |
877 | |
878 | buffer_info_.prepareOuterSharingAnalysis(); |
879 | handleScope(exprs); |
880 | } |
881 | |
882 | // Second visit of an allocate op |
883 | void handle(kir::Allocate* allocate) { |
    // Check whether this allocation site is one that
    //  we want to re-use or replace with an alias.
886 | |
887 | auto maybe_alloc_info = buffer_info_.getMaybeReuseInfoFor(allocate); |
888 | if (maybe_alloc_info.has_value() && |
889 | maybe_alloc_info.value()->alias_to == nullptr) { |
890 | // Try to re-use existing allocates |
891 | if (!tryReuseOtherAllocate(maybe_alloc_info.value())) { |
892 | // If didn't re-use, should register this |
893 | // allocate so that future allocates |
894 | // can re-use this one. |
895 | current_visible_buffer_stack_.back()->push_back( |
896 | maybe_alloc_info.value()); |
897 | } |
898 | } |
899 | } |
900 | |
901 | bool tryReuseOtherAllocate(AllocationInfoPtr alloc_info) { |
902 | if (!alloc_info->should_try_alias) { |
903 | return false; |
904 | } |
905 | if (!alloc_info->inner_alias_list_->empty()) { |
      // Avoid 2-hop aliasing for simplicity. Could be supported if
      //  really needed in extreme cases.
908 | return false; |
909 | } |
910 | |
911 | // Move backwards on list of re-usable allocates on the stack, prefer |
912 | // reusing nearest allocation |
913 | for (auto reuse_stack_it = current_visible_buffer_stack_.rbegin(); |
914 | reuse_stack_it != current_visible_buffer_stack_.rend(); |
915 | reuse_stack_it++) { |
916 | for (auto alloc_to_reuse_it = (*reuse_stack_it)->rbegin(); |
917 | alloc_to_reuse_it != (*reuse_stack_it)->rend(); |
918 | alloc_to_reuse_it++) { |
919 | auto alloc_to_reuse = *alloc_to_reuse_it; |
920 | |
921 | // Check if this re-use candidate is an alias |
922 | if (alloc_to_reuse->alias_to != nullptr) { |
923 | continue; |
924 | } |
925 | |
926 | // Check if this alloc has the same mem type |
927 | if (alloc_info->mem_type != alloc_to_reuse->mem_type) { |
928 | continue; |
929 | } |
930 | |
931 | // Check if this alloc has the same size |
932 | if (alloc_info->size_expr != alloc_to_reuse->size_expr) { |
933 | continue; |
934 | } |
935 | |
936 | // Check if this alloc has the same data type |
937 | if (alloc_info->data_type != alloc_to_reuse->data_type) { |
938 | continue; |
939 | } |
940 | |
941 | // Check if live intervals have any overlap |
942 | auto subscribed_intervals = inner_aliasing_pass_ |
943 | ? alloc_to_reuse->inner_subscribed_intevals.get() |
944 | : alloc_to_reuse->outer_subscribed_intevals.get(); |
945 | |
946 | auto alloc_live_interval = inner_aliasing_pass_ |
947 | ? alloc_info->inner_live_interval.get() |
948 | : alloc_info->outer_live_interval.get(); |
949 | |
950 | if (std::any_of( |
951 | subscribed_intervals->begin(), |
952 | subscribed_intervals->end(), |
953 | [alloc_live_interval](auto subscribed_interval) { |
954 | return alloc_live_interval->intersect(subscribed_interval); |
955 | })) { |
956 | continue; |
957 | } |
958 | |
959 | // Special checks for inner sharing pass |
960 | if (inner_aliasing_pass_ && |
961 | !isValidInnerSharing(alloc_to_reuse, alloc_info)) { |
962 | continue; |
963 | } |
964 | |
        if (alloc_info->alloc_expr->buffer()->isA<TensorView>()) {
          if (!alloc_to_reuse->alloc_expr->buffer()->isA<TensorView>()) {
            continue;
          }
          auto this_tv = alloc_info->alloc_expr->buffer()->as<TensorView>();
          auto reuse_tv =
              alloc_to_reuse->alloc_expr->buffer()->as<TensorView>();
          // Check that either both tv's are vectorized accesses, or
          //  neither is. Vectorized allocations require correct alignment
          //  so they can only alias with other allocations with the right
          //  alignment.
          const auto& va = GpuLower::current()->vectorizedAccesses();
          if ((va.find(this_tv) == va.end()) !=
              (va.find(reuse_tv) == va.end())) {
            return false;
          }
979 | |
980 | // Shared memory is all aligned to 128 bits, local memory might not be |
981 | if (this_tv->getMemoryType() == MemoryType::Local && |
982 | va.find(this_tv) != va.end()) { |
983 | // Make sure alignment matches |
984 | if (va.at(this_tv) != va.at(reuse_tv)) { |
985 | return false; |
986 | } |
987 | } |
988 | } |
989 | |
990 | // TODO: |
991 | // Outer interval based sharing supports arbitrary re-indexing into |
992 | // the same buffer and would require additional syncs if fully |
993 | // enabled. |
994 | // Need a few more checks to insert syncs if necessary before turning |
995 | // on this sharing. |
996 | if (!inner_aliasing_pass_ && |
997 | alloc_info->mem_type == MemoryType::Shared) { |
998 | continue; |
999 | } |
1000 | |
1001 | // Now re-use the alloc here and be sure to update |
1002 | reUseAllocation(alloc_info, alloc_to_reuse); |
1003 | return true; |
1004 | } |
1005 | } |
1006 | return false; |
1007 | } |
1008 | |
1009 | void handle(Expr* expr) { |
1010 | if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) { |
1011 | handle(ite); |
1012 | } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) { |
1013 | handle(for_loop); |
1014 | } else if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) { |
1015 | handle(allocate); |
1016 | } |
1017 | } |
1018 | |
1019 | void handle(const kir::ForLoop* for_loop) { |
1020 | handleScope(for_loop->body().exprs()); |
1021 | } |
1022 | |
  void handle(const kir::IfThenElse* ite) {
1024 | TORCH_INTERNAL_ASSERT( |
1025 | false, |
1026 | "lower_alias_memory: IfThenElse before unrolling is not yet supported" ); |
1027 | } |
1028 | |
1029 | void handleScope(const std::vector<Expr*>& exprs) { |
1030 | current_visible_buffer_stack_.emplace_back( |
1031 | std::make_unique<AllocationInfoList>()); |
1032 | for (auto expr : exprs) { |
1033 | handle(expr); |
1034 | } |
1035 | current_visible_buffer_stack_.pop_back(); |
1036 | } |
1037 | |
1038 | struct InPlaceSharingInfo { |
1039 | bool has_broadcast_between = false; |
1040 | bool has_unsupported_op = false; |
1041 | }; |
1042 | |
  //! Careful, heavyweight check on inner sharing candidates;
  //!  the currently enforced conditions are:
1045 | //! |
1046 | //! 1. The two buffers have producer-consumer relationship |
1047 | //! 2. No halo in the allocated iter domains |
1048 | //! 3. Require index equivalence when sharing across broadcast |
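  //!
  //! A hypothetical example of a pair that passes these checks, with
  //!  T1 and T2 both in local memory and of the same size and dtype:
  //!
  //!   T2[I] = T1[I] + 1
  //!   T3[I] = T2[I] * 2
  //!
  //!  Only a pointwise op sits between T1 and T2 and there is no halo
  //!  or swizzle, so T2 could alias T1 in place: every element of T1 is
  //!  read exactly once before the same element of T2 is written.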
1049 | bool isValidInnerSharing( |
1050 | AllocationUseDefInfo* alloc_info, |
1051 | AllocationUseDefInfo* to_reuse) { |
1052 | // Disable if either of the buffers do not support inner sharing |
1053 | if (!alloc_info->can_use_inner_alias || !to_reuse->can_use_inner_alias) { |
1054 | return false; |
1055 | } |
1056 | // Assume inputs are TV allocations, which should have been checked |
1057 | // before reaching this point. |
1058 | auto this_tv = alloc_info->alloc_expr->buffer()->as<TensorView>(); |
1059 | auto reuse_tv = to_reuse->alloc_expr->buffer()->as<TensorView>(); |
1060 | |
1061 | // Aggressively disable inner sharing for swizzled tvs since |
1062 | // the indexing order is in general not tractable. |
1063 | // But outer sharing should still apply. |
1064 | if (this_tv->hasSwizzleOp() || reuse_tv->hasSwizzleOp()) { |
1065 | return false; |
1066 | } |
1067 | |
1068 | // Check the values in between the two buffers. |
1069 | auto vals_between_this_and_reuse = |
1070 | DependencyCheck::getAllValsBetween({this_tv}, {reuse_tv}); |
1071 | if (vals_between_this_and_reuse.empty()) { |
1072 | vals_between_this_and_reuse = |
1073 | DependencyCheck::getAllValsBetween({reuse_tv}, {this_tv}); |
1074 | } |
1075 | |
1076 | if (!vals_between_this_and_reuse.empty()) { |
      // Temporarily disable inner sharing across difficult
      //  ops; this can be relaxed gradually.
1079 | auto topo_info = checkOpsInBetween(vals_between_this_and_reuse); |
1080 | |
1081 | // Avoid difficult and future introduced ops |
1082 | if (topo_info.has_unsupported_op) { |
1083 | return false; |
1084 | } |
1085 | |
1086 | // Get information on the allocated domains of the |
1087 | // two buffers |
1088 | auto& local_alloc_map = GpuLower::current()->localAllocationInfoMap(); |
1089 | auto alloc_it = local_alloc_map.find(alloc_info->alloc_expr); |
1090 | auto to_reuse_it = local_alloc_map.find(to_reuse->alloc_expr); |
1091 | if (alloc_it == local_alloc_map.end() || |
1092 | to_reuse_it == local_alloc_map.end()) { |
1093 | return false; |
1094 | } |
1095 | |
      // Disable in-place reusing for halo ops, since with halo a
      //  pointwise op can be issued multiple times for some points.
1098 | if (alloc_it->second->has_halo || to_reuse_it->second->has_halo) { |
1099 | return false; |
1100 | } |
1101 | |
1102 | // Require matched iterdomains for sharing across broadcast |
1103 | if (topo_info.has_broadcast_between) { |
1104 | auto& alloc_domains = alloc_it->second->alloc_domains; |
1105 | auto& reuse_domains = to_reuse_it->second->alloc_domains; |
1106 | |
1107 | return allocationDomainsIndexMapped(alloc_domains, reuse_domains); |
1108 | } |
1109 | |
      // If there are only pointwise and reduction ops in between and
      //  no broadcast, it should be ok to re-use in place.
1112 | return true; |
1113 | } |
1114 | |
1115 | // this and reuse are not dependencies of each other, |
1116 | // which means we cannot use inner sharing. |
1117 | return false; |
1118 | } |
1119 | |
1120 | InPlaceSharingInfo checkOpsInBetween(std::vector<Val*>& all_used_vals) { |
1121 | InPlaceSharingInfo info; |
1122 | std::unordered_set<Val*> all_used_val_set( |
1123 | all_used_vals.begin(), all_used_vals.end()); |
1124 | |
1125 | for (auto val : all_used_vals) { |
1126 | if (auto tv = dynamic_cast<TensorView*>(val)) { |
1127 | auto tv_def = tv->definition(); |
1128 | if (!tv_def) { |
1129 | continue; |
1130 | } |
1131 | if (!isPointwiseTvOp(tv_def) && !ir_utils::isReductionTvOp(tv_def)) { |
1132 | if (isBroadcastTvOp(tv_def)) { |
1133 | info.has_broadcast_between = true; |
1134 | } else { |
1135 | info.has_unsupported_op = true; |
1136 | } |
1137 | } |
1138 | } |
1139 | } |
1140 | return info; |
1141 | } |
1142 | |
1143 | bool allocationDomainsIndexMapped( |
1144 | std::vector<IterDomain*>& alloc_domains, |
1145 | std::vector<IterDomain*>& reuse_domains) { |
1146 | // Require that the allocated domains are exactly mapped. |
1147 | if (alloc_domains.size() != reuse_domains.size()) { |
1148 | return false; |
1149 | } |
1150 | |
1151 | // Check index map for the corresponding axes. |
1152 | for (const auto id_it : c10::irange(alloc_domains.size())) { |
1153 | if (!GpuLower::current()->caMap()->areMapped( |
1154 | alloc_domains[id_it], |
1155 | reuse_domains[id_it], |
1156 | IdMappingMode::EXACT)) { |
1157 | return false; |
1158 | } |
1159 | } |
1160 | return true; |
1161 | } |
1162 | |
1163 | void reUseAllocation( |
1164 | AllocationUseDefInfo* alloc_info, |
1165 | AllocationUseDefInfo* to_reuse) { |
1166 | // Update analysis result |
1167 | if (inner_aliasing_pass_) { |
1168 | buffer_info_.useInnerAlias(alloc_info, to_reuse); |
1169 | } else { |
1170 | buffer_info_.useOuterAlias(alloc_info, to_reuse); |
1171 | } |
1172 | } |
1173 | |
1174 | // Do we have a true pointwise op? |
1175 | // (ie. a TV op, excluding direct assignments and reductions) |
1176 | bool isPointwiseTvOp(const Expr* expr) { |
1177 | if (ir_utils::isTvOp(expr)) { |
1178 | return expr->isA<UnaryOp>() || expr->isA<BinaryOp>() || |
1179 | expr->isA<TernaryOp>(); |
1180 | } |
1181 | return false; |
1182 | } |
1183 | |
  // Utility to capture broadcast ops
1185 | bool isBroadcastTvOp(const Expr* expr) { |
1186 | if (!ir_utils::isTvOp(expr)) { |
1187 | return false; |
1188 | } |
1189 | return expr->isA<BroadcastOp>(); |
1190 | } |
1191 | |
1192 | private: |
1193 | // Analysis result from the first pass collecting the use-defs |
1194 | BufferUseDefInfo buffer_info_; |
1195 | |
  // Internal data keeping track of currently visible allocations as
  //  the pass iterates through the expr list, grouped by the stack
  //  layer of alloc ops.
1199 | std::vector<std::unique_ptr<AllocationInfoList>> |
1200 | current_visible_buffer_stack_; |
1201 | |
1202 | // Marks state of current pass |
1203 | bool inner_aliasing_pass_ = true; |
1204 | }; |
1205 | |
1206 | } // namespace |
1207 | |
1208 | std::vector<Expr*> reuseMemoryAllocations(const std::vector<Expr*>& exprs) { |
1209 | FUSER_PERF_SCOPE("reuseMemoryAllocations" ); |
1210 | bool debug_print = isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo); |
1211 | if (debug_print) { |
1212 | AllocateReuseModifier::debugPrint(exprs); |
1213 | } |
1214 | AllocateReuseModifier::modify(exprs); |
1215 | return exprs; |
1216 | } |
1217 | |
1218 | } // namespace cuda |
1219 | } // namespace fuser |
1220 | } // namespace jit |
1221 | } // namespace torch |
1222 | |