1 | #pragma once |
---|---|
2 | |
#include <c10/macros/Export.h>

#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <memory>
#include <unordered_map>
#include <vector>
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | //! Buffer allocation information to store in GPU lower to avoid |
16 | //! logic duplication |
17 | struct LocalAllocationInfo { |
18 | kir::Allocate* alloc_expr = nullptr; |
19 | std::vector<IterDomain*> alloc_domains; |
20 | bool has_halo = false; |
21 | }; |
22 | |
23 | using LocalAllocationInfoMap = |
24 | std::unordered_map<kir::Allocate*, std::unique_ptr<LocalAllocationInfo>>; |
25 | |
26 | //! Insert buffer allocations |
27 | std::vector<Expr*> insertAllocations(const std::vector<Expr*>& exprs); |
28 | |
29 | } // namespace cuda |
30 | } // namespace fuser |
31 | } // namespace jit |
32 | } // namespace torch |
33 |