#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
4
5#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
6
7namespace torch {
8namespace cuda {
9namespace CUDAPluggableAllocator {
10
// Device count forwarded to CUDAPluggableAllocator::init() by
// createCustomAllocator(). NOTE(review): nothing in this file updates it from
// the CUDA runtime, so it stays 0 unless set elsewhere -- confirm intended.
int device_count = 0;

// Forward declaration; defined at the bottom of this file. This is the
// deleter returned by CUDAPluggableAllocator::raw_deleter().
void custom_raw_deleter(void* ptr);
14
15_AllocationMetadata::_AllocationMetadata()
16 : size(0), device_idx(-1), stream(0) {}
17
18_AllocationMetadata::_AllocationMetadata(
19 size_t size,
20 int device_idx,
21 cudaStream_t stream)
22 : size(size), device_idx(device_idx), stream(stream) {}
23
24// This is a fast API to just register allocators
25// based on function pointers (ie. external .so libraries)
26// This avoids having to link against libtorch for C++ based custom allocators
27// And also use this from python
28CUDAPluggableAllocator::CUDAPluggableAllocator(
29 std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
30 std::function<void(void*, size_t, int, cudaStream_t)> free_fn)
31 : alloc_fn_(alloc_fn), free_fn_(free_fn) {}
32
// Copy constructor: duplicates every registered hook from `other`.
// NOTE(review): the parameter is a non-const reference (must match the header
// declaration -- confirm before changing to const&). Per-allocation state
// (allocation_metadata_, initialized_, the mutex) is not copied here, so the
// copy starts with no outstanding allocations.
CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other)
    : alloc_fn_(other.alloc_fn_),
      free_fn_(other.free_fn_),
      init_fn_(other.init_fn_),
      reset_fn_(other.reset_fn_),
      memory_fraction_fn_(other.memory_fraction_fn_),
      base_alloc_fn_(other.base_alloc_fn_),
      record_stream_fn_(other.record_stream_fn_),
      capture_begin_fn_(other.capture_begin_fn_),
      capture_about_to_end_fn_(other.capture_about_to_end_fn_),
      capture_ended_fn_(other.capture_ended_fn_),
      capture_destroy_fn_(other.capture_destroy_fn_) {}
45
46void CUDAPluggableAllocator::set_init_fn(std::function<void(int)> init_fn) {
47 init_fn_ = init_fn;
48}
49
50void CUDAPluggableAllocator::set_reset_fn(std::function<void()> reset_fn) {
51 reset_fn_ = reset_fn;
52}
53
54void CUDAPluggableAllocator::set_memory_fraction_fn(
55 std::function<void(double, int)> memory_fraction_fn) {
56 memory_fraction_fn_ = memory_fraction_fn;
57}
58
59void CUDAPluggableAllocator::set_base_alloc_fn(
60 std::function<void*(void*, size_t*)> base_alloc_fn) {
61 base_alloc_fn_ = base_alloc_fn;
62}
63
64void CUDAPluggableAllocator::set_record_stream_fn(
65 std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn) {
66 record_stream_fn_ = record_stream_fn;
67}
68
69void CUDAPluggableAllocator::set_capture_begin_fn(
70 std::function<void(int, c10::cuda::CaptureId_t, c10::cuda::MempoolId_t)>
71 capture_begin_fn) {
72 capture_begin_fn_ = capture_begin_fn;
73}
74
75void CUDAPluggableAllocator::set_capture_about_to_end_fn(
76 std::function<void(int, c10::cuda::CaptureId_t)> capture_about_to_end_fn) {
77 capture_about_to_end_fn_ = capture_about_to_end_fn;
78}
79
80void CUDAPluggableAllocator::set_capture_ended_fn(
81 std::function<void(int, c10::cuda::CaptureId_t)> capture_ended_fn) {
82 capture_ended_fn_ = capture_ended_fn;
83}
84
85void CUDAPluggableAllocator::set_capture_destroy_fn(
86 std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn) {
87 capture_destroy_fn_ = capture_destroy_fn;
88}
89
90void* CUDAPluggableAllocator::malloc(
91 size_t size,
92 int device,
93 cudaStream_t stream) {
94 void* r = alloc_fn_(size, device, stream);
95 {
96 const std::lock_guard<std::mutex> lock(allocator_mutex_);
97 allocation_metadata_.emplace(r, _AllocationMetadata(size, device, stream));
98 }
99 return r;
100}
101
102c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) const {
103 int device;
104 C10_CUDA_CHECK(cudaGetDevice(&device));
105 cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
106 void* r =
107 const_cast<CUDAPluggableAllocator*>(this)->malloc(size, device, stream);
108 c10::DataPtr data_ptr = {
109 r, r, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
110 return data_ptr;
111}
112
113c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const {
114 return &custom_raw_deleter;
115}
116
117void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) {
118 int device;
119 C10_CUDA_CHECK(cudaGetDevice(&device));
120 cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
121 return malloc(nbytes, device, stream);
122}
123
124void* CUDAPluggableAllocator::raw_alloc_with_stream(
125 size_t nbytes,
126 cudaStream_t stream) {
127 int device;
128 C10_CUDA_CHECK(cudaGetDevice(&device));
129 return malloc(nbytes, device, stream);
130}
131
132void CUDAPluggableAllocator::raw_delete(void* ptr) {
133 cudaStream_t stream;
134 int device_idx;
135 size_t size;
136 {
137 const std::lock_guard<std::mutex> lock(allocator_mutex_);
138 TORCH_CHECK(
139 allocation_metadata_.count(ptr),
140 "Trying to free a pointer not allocated here");
141 _AllocationMetadata& metadata = allocation_metadata_[ptr];
142 size = metadata.size;
143 device_idx = metadata.device_idx;
144 stream = metadata.stream;
145 allocation_metadata_.erase(ptr);
146 }
147 free_fn_(ptr, size, device_idx, stream);
148}
149
150void CUDAPluggableAllocator::init(int device_count) {
151 if (init_fn_) {
152 init_fn_(device_count);
153 }
154 initialized_ = true;
155}
156
// Whether init() has been called on this allocator.
bool CUDAPluggableAllocator::initialized() {
  return initialized_;
}
160
161void CUDAPluggableAllocator::setMemoryFraction(double fraction, int device) {
162 if (memory_fraction_fn_) {
163 memory_fraction_fn_(fraction, device);
164 }
165}
166
167void CUDAPluggableAllocator::emptyCache(void) {
168 if (reset_fn_) {
169 return reset_fn_();
170 }
171}
172
173void CUDAPluggableAllocator::cacheInfo(int dev_id, size_t* largestBlock) {
174 TORCH_CHECK(
175 false,
176 "CUDAPluggableAllocator does not yet support cacheInfo. "
177 "If you need it, please file an issue describing your use case.");
178}
179
180void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) {
181 if (base_alloc_fn_) {
182 return base_alloc_fn_(ptr, size);
183 } else {
184 return ptr;
185 }
186}
187
188void CUDAPluggableAllocator::recordStream(
189 const c10::DataPtr& ptr,
190 streamType stream) {
191 if (record_stream_fn_) {
192 record_stream_fn_(ptr.get(), stream);
193 }
194}
195
196c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator::
197 getDeviceStats(int device) {
198 TORCH_CHECK(
199 false,
200 "CUDAPluggableAllocator does not yet support getDeviceStats. "
201 "If you need it, please file an issue describing your use case.");
202}
203
204void CUDAPluggableAllocator::resetAccumulatedStats(int device) {
205 TORCH_CHECK(
206 false,
207 "CUDAPluggableAllocator does not yet support resetAccumulatedStats. "
208 "If you need it, please file an issue describing your use case.");
209}
210
211void CUDAPluggableAllocator::resetPeakStats(int device) {
212 TORCH_CHECK(
213 false,
214 "CUDAPluggableAllocator does not yet support resetPeakStats. "
215 "If you need it, please file an issue describing your use case.");
216}
217
218c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
219 snapshot() {
220 TORCH_CHECK(
221 false,
222 "CUDAPluggableAllocator does not yet support snapshot. "
223 "If you need it, please file an issue describing your use case.");
224}
225
226std::shared_ptr<void> CUDAPluggableAllocator::getIpcDevPtr(std::string handle) {
227 TORCH_CHECK(
228 false,
229 "CUDAPluggableAllocator does not yet support getIpcDevPtr. "
230 "If you need it, please file an issue describing your use case.");
231}
232
233// CUDAGraph interactions
234void CUDAPluggableAllocator::notifyCaptureBegin(
235 int device,
236 c10::cuda::CaptureId_t graph_id,
237 c10::cuda::MempoolId_t mempool_id) {
238 if (capture_begin_fn_) {
239 capture_begin_fn_(device, graph_id, mempool_id);
240 }
241}
242
243void CUDAPluggableAllocator::notifyCaptureAboutToEnd(
244 int device,
245 c10::cuda::CaptureId_t graph_id) {
246 if (capture_about_to_end_fn_) {
247 capture_about_to_end_fn_(device, graph_id);
248 }
249}
250
251void CUDAPluggableAllocator::notifyCaptureEnded(
252 int device,
253 c10::cuda::CaptureId_t graph_id) {
254 if (capture_ended_fn_) {
255 capture_ended_fn_(device, graph_id);
256 }
257}
258
259void CUDAPluggableAllocator::notifyCaptureDestroy(
260 int device,
261 c10::cuda::MempoolId_t mempool_id) {
262 if (capture_destroy_fn_) {
263 capture_destroy_fn_(device, mempool_id);
264 }
265}
266
267void CUDAPluggableAllocator::recordHistory(
268 bool enabled,
269 c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
270 size_t alloc_trace_max_entries,
271 bool alloc_trace_record_context) {
272 TORCH_CHECK(
273 false,
274 "CUDAPluggableAllocator does not yet support recordHistory. "
275 "If you need it, please file an issue describing your use case.");
276}
277
278void CUDAPluggableAllocator::attachOutOfMemoryObserver(
279 c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) {
280 TORCH_CHECK(
281 false,
282 "CUDAPluggableAllocator does not yet support attachOutOfMemoryObserver. "
283 "If you need it, please file an issue describing your use case.");
284}
285
// Always reports false: no pool-specific peer access handling is advertised
// by pluggable allocators.
bool CUDAPluggableAllocator::needsPoolSpecificPeerAccess() {
  return false;
}
289
290std::string CUDAPluggableAllocator::name() {
291 return "pluggable";
292}
293
// The allocator installed by changeCurrentAllocator(). Holding a shared_ptr
// here keeps the object alive for the raw pointer stored into
// c10::cuda::CUDACachingAllocator::allocator; custom_raw_deleter() also
// dispatches frees through it.
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
    current_custom_allocator;
296
// Returns the allocator most recently installed via changeCurrentAllocator(),
// or a null shared_ptr if none has been installed yet.
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
getCurrentAllocator() {
  return current_custom_allocator;
}
301
302// TODO: add more functions in the argument
303std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
304createCustomAllocator(
305 std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
306 std::function<void(void*, size_t, int, cudaStream_t)> free_fn) {
307 std::shared_ptr<CUDAPluggableAllocator> allocator(
308 new CUDAPluggableAllocator(alloc_fn, free_fn));
309 allocator->init(device_count);
310 return allocator;
311}
312
313void changeCurrentAllocator(
314 std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator) {
315 TORCH_CHECK(
316 !c10::cuda::CUDACachingAllocator::allocator.load()->initialized(),
317 "Can't swap an already initialized allocator");
318 c10::cuda::CUDACachingAllocator::allocator.store(allocator.get());
319 current_custom_allocator = allocator;
320}
321
// Trampoline used as the DataPtr deleter (see raw_deleter()): routes frees
// through whichever custom allocator is currently installed.
// NOTE(review): dereferences current_custom_allocator unconditionally -- a
// DataPtr outliving the installed allocator would crash here; confirm
// lifetime guarantees at the call sites.
void custom_raw_deleter(void* ptr) {
  current_custom_allocator->raw_delete(ptr);
}
325
326} // namespace CUDAPluggableAllocator
327} // namespace cuda
328} // namespace torch
329