1 | // Virtual memory allocator for CPU/GPU |
---|---|
2 | |
3 | #if defined(TI_WITH_CUDA) |
4 | #include "taichi/rhi/cuda/cuda_driver.h" |
5 | #include "taichi/rhi/cuda/cuda_context.h" |
6 | #include "taichi/rhi/cuda/cuda_device.h" |
7 | |
8 | #endif |
9 | #include "taichi/util/lang_util.h" |
10 | #include "taichi/system/unified_allocator.h" |
11 | #include "taichi/system/virtual_memory.h" |
12 | #include "taichi/system/timer.h" |
13 | #include "taichi/rhi/cpu/cpu_device.h" |
14 | #include <string> |
15 | |
16 | namespace taichi::lang { |
17 | |
18 | UnifiedAllocator::UnifiedAllocator(std::size_t size, Arch arch, Device *device) |
19 | : size_(size), arch_(arch), device_(device) { |
20 | auto t = Time::get_time(); |
21 | if (arch_ == Arch::x64) { |
22 | #ifdef TI_WITH_LLVM |
23 | Device::AllocParams alloc_params; |
24 | alloc_params.size = size; |
25 | alloc_params.host_read = true; |
26 | alloc_params.host_write = true; |
27 | |
28 | cpu::CpuDevice *cpu_device = static_cast<cpu::CpuDevice *>(device); |
29 | alloc = cpu_device->allocate_memory(alloc_params); |
30 | data = (uint8 *)cpu_device->get_alloc_info(alloc).ptr; |
31 | #else |
32 | TI_NOT_IMPLEMENTED |
33 | #endif |
34 | } else { |
35 | TI_TRACE("Allocating virtual address space of size {} MB", |
36 | size / 1024 / 1024); |
37 | cpu_vm_ = std::make_unique<VirtualMemoryAllocator>(size); |
38 | data = (uint8 *)cpu_vm_->ptr; |
39 | } |
40 | TI_ASSERT(data != nullptr); |
41 | TI_ASSERT(uint64(data) % 4096 == 0); |
42 | |
43 | head = data; |
44 | tail = head + size; |
45 | TI_TRACE("Memory allocated. Allocation time = {:.3} s", Time::get_time() - t); |
46 | } |
47 | |
48 | taichi::lang::UnifiedAllocator::~UnifiedAllocator() { |
49 | if (!initialized()) { |
50 | return; |
51 | } |
52 | if (arch_ == Arch::x64) { |
53 | cpu::CpuDevice *cpu_device = static_cast<cpu::CpuDevice *>(device_); |
54 | cpu_device->dealloc_memory(alloc); |
55 | } |
56 | } |
57 | |
58 | void taichi::lang::UnifiedAllocator::memset(unsigned char val) { |
59 | std::memset(data, val, size_); |
60 | } |
61 | |
62 | } // namespace taichi::lang |
63 |