1#pragma once
2
#include <c10/macros/Export.h>

#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <memory>
#include <unordered_map>
#include <vector>
9
10namespace torch {
11namespace jit {
12namespace fuser {
13namespace cuda {
14
15//! Buffer allocation information to store in GPU lower to avoid
16//! logic duplication
17struct LocalAllocationInfo {
18 kir::Allocate* alloc_expr = nullptr;
19 std::vector<IterDomain*> alloc_domains;
20 bool has_halo = false;
21};
22
23using LocalAllocationInfoMap =
24 std::unordered_map<kir::Allocate*, std::unique_ptr<LocalAllocationInfo>>;
25
26//! Insert buffer allocations
27std::vector<Expr*> insertAllocations(const std::vector<Expr*>& exprs);
28
29} // namespace cuda
30} // namespace fuser
31} // namespace jit
32} // namespace torch
33