1 | #pragma once |
2 | |
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>

#include <c10/cuda/CUDACachingAllocator.h>

#include <array>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
12 | |
13 | namespace torch { |
14 | |
15 | namespace cuda { |
16 | |
17 | namespace CUDAPluggableAllocator { |
18 | |
19 | #if defined(TORCH_HIP_VERSION) |
20 | using streamType = c10::hip::HIPStream; |
21 | #else |
22 | using streamType = c10::cuda::CUDAStream; |
23 | #endif |
24 | |
25 | std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> |
26 | getCurrentAllocator(); |
27 | std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> |
28 | createCustomAllocator( |
29 | std::function<void*(size_t, int, cudaStream_t)> alloc_fn, |
30 | std::function<void(void*, size_t, int, cudaStream_t)> free_fn); |
31 | void changeCurrentAllocator( |
32 | std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator); |
33 | |
34 | struct _AllocationMetadata { |
35 | _AllocationMetadata(); |
36 | _AllocationMetadata(size_t size, int device_idx, cudaStream_t stream); |
37 | size_t size; |
38 | int device_idx; |
39 | cudaStream_t stream; |
40 | }; |
41 | |
42 | struct CUDAPluggableAllocator |
43 | : public c10::cuda::CUDACachingAllocator::CUDAAllocator { |
44 | CUDAPluggableAllocator( |
45 | std::function<void*(size_t, int, cudaStream_t)> alloc_fn, |
46 | std::function<void(void*, size_t, int, cudaStream_t)> free_fn); |
47 | |
48 | CUDAPluggableAllocator(CUDAPluggableAllocator& other); |
49 | |
50 | void set_init_fn(std::function<void(int)> init_fn); |
51 | |
52 | void set_reset_fn(std::function<void()> reset_fn); |
53 | |
54 | void set_memory_fraction_fn( |
55 | std::function<void(double, int)> memory_fraction_fn); |
56 | |
57 | void set_base_alloc_fn(std::function<void*(void*, size_t*)> base_alloc_fn); |
58 | |
59 | void set_record_stream_fn( |
60 | std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn); |
61 | |
62 | void set_capture_begin_fn( |
63 | std::function<void(int, c10::cuda::CaptureId_t, c10::cuda::MempoolId_t)> |
64 | capture_begin_fn); |
65 | |
66 | void set_capture_about_to_end_fn( |
67 | std::function<void(int, c10::cuda::CaptureId_t)> capture_about_to_end_fn); |
68 | |
69 | void set_capture_ended_fn( |
70 | std::function<void(int, c10::cuda::CaptureId_t)> capture_ended_fn); |
71 | |
72 | void set_capture_destroy_fn( |
73 | std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn); |
74 | |
75 | void* malloc(size_t size, int device, cudaStream_t stream); |
76 | |
77 | c10::DataPtr allocate(size_t size) const override; |
78 | c10::DeleterFnPtr raw_deleter() const override; |
79 | |
80 | virtual void* raw_alloc(size_t nbytes) override; |
81 | virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) |
82 | override; |
83 | virtual void raw_delete(void* ptr) override; |
84 | virtual void init(int device_count) override; |
85 | virtual bool initialized() override; |
86 | virtual void setMemoryFraction(double fraction, int device) override; |
87 | virtual void emptyCache() override; |
88 | virtual void cacheInfo(int dev_id, size_t* largestBlock) override; |
89 | virtual void* getBaseAllocation(void* ptr, size_t* size) override; |
90 | |
91 | virtual void recordStream(const c10::DataPtr&, streamType stream) override; |
92 | |
93 | virtual c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats( |
94 | int device) override; |
95 | virtual void resetAccumulatedStats(int device) override; |
96 | virtual void resetPeakStats(int device) override; |
97 | virtual c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override; |
98 | virtual void notifyCaptureBegin( |
99 | int device, |
100 | c10::cuda::CaptureId_t graph_id, |
101 | c10::cuda::MempoolId_t mempool_id) override; |
102 | virtual void notifyCaptureAboutToEnd( |
103 | int device, |
104 | c10::cuda::CaptureId_t graph_id) override; |
105 | virtual void notifyCaptureEnded(int device, c10::cuda::CaptureId_t graph_id) |
106 | override; |
107 | virtual void notifyCaptureDestroy( |
108 | int device, |
109 | c10::cuda::MempoolId_t mempool_id) override; |
110 | virtual std::shared_ptr<void> getIpcDevPtr(std::string handle) override; |
111 | virtual void recordHistory( |
112 | bool enabled, |
113 | c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder, |
114 | size_t alloc_trace_max_entries, |
115 | bool alloc_trace_record_context) override; |
116 | virtual void attachOutOfMemoryObserver( |
117 | c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override; |
118 | virtual bool needsPoolSpecificPeerAccess() override; |
119 | virtual std::string name() override; |
120 | |
121 | protected: |
122 | std::function<void*(size_t, int, cudaStream_t)> alloc_fn_; |
123 | std::function<void(void*, size_t, int, cudaStream_t)> free_fn_; |
124 | std::function<void(int)> init_fn_; |
125 | std::function<void()> reset_fn_; |
126 | std::function<void(double, int)> memory_fraction_fn_; |
127 | std::function<void*(void*, size_t*)> base_alloc_fn_; |
128 | std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn_; |
129 | std::function<void(int, c10::cuda::CaptureId_t, c10::cuda::MempoolId_t)> |
130 | capture_begin_fn_; |
131 | std::function<void(int, c10::cuda::CaptureId_t)> capture_about_to_end_fn_; |
132 | std::function<void(int, c10::cuda::CaptureId_t)> capture_ended_fn_; |
133 | std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn_; |
134 | std::mutex allocator_mutex_; |
135 | // We do the bookeeping here in order to simplify custom allocators |
136 | std::unordered_map<void*, _AllocationMetadata> allocation_metadata_; |
137 | |
138 | bool initialized_ = false; |
139 | }; |
140 | } // namespace CUDAPluggableAllocator |
141 | } // namespace cuda |
142 | } // namespace torch |
143 | |