#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>

#include <torch/csrc/cuda/CUDAPluggableAllocator.h>
6 | |
7 | namespace torch { |
8 | namespace cuda { |
9 | namespace CUDAPluggableAllocator { |
10 | |
// Device count handed to CUDAPluggableAllocator::init() by
// createCustomAllocator(). NOTE(review): never written in this translation
// unit — presumably updated elsewhere or intentionally left 0; confirm.
int device_count = 0;

// Forward declaration; defined at the bottom of this file. Used as the
// c10::DeleterFnPtr returned by CUDAPluggableAllocator::raw_deleter().
void custom_raw_deleter(void* ptr);
14 | |
15 | _AllocationMetadata::_AllocationMetadata() |
16 | : size(0), device_idx(-1), stream(0) {} |
17 | |
18 | _AllocationMetadata::_AllocationMetadata( |
19 | size_t size, |
20 | int device_idx, |
21 | cudaStream_t stream) |
22 | : size(size), device_idx(device_idx), stream(stream) {} |
23 | |
24 | // This is a fast API to just register allocators |
25 | // based on function pointers (ie. external .so libraries) |
26 | // This avoids having to link against libtorch for C++ based custom allocators |
27 | // And also use this from python |
28 | CUDAPluggableAllocator::CUDAPluggableAllocator( |
29 | std::function<void*(size_t, int, cudaStream_t)> alloc_fn, |
30 | std::function<void(void*, size_t, int, cudaStream_t)> free_fn) |
31 | : alloc_fn_(alloc_fn), free_fn_(free_fn) {} |
32 | |
33 | CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other) |
34 | : alloc_fn_(other.alloc_fn_), |
35 | free_fn_(other.free_fn_), |
36 | init_fn_(other.init_fn_), |
37 | reset_fn_(other.reset_fn_), |
38 | memory_fraction_fn_(other.memory_fraction_fn_), |
39 | base_alloc_fn_(other.base_alloc_fn_), |
40 | record_stream_fn_(other.record_stream_fn_), |
41 | capture_begin_fn_(other.capture_begin_fn_), |
42 | capture_about_to_end_fn_(other.capture_about_to_end_fn_), |
43 | capture_ended_fn_(other.capture_ended_fn_), |
44 | capture_destroy_fn_(other.capture_destroy_fn_) {} |
45 | |
46 | void CUDAPluggableAllocator::set_init_fn(std::function<void(int)> init_fn) { |
47 | init_fn_ = init_fn; |
48 | } |
49 | |
50 | void CUDAPluggableAllocator::set_reset_fn(std::function<void()> reset_fn) { |
51 | reset_fn_ = reset_fn; |
52 | } |
53 | |
54 | void CUDAPluggableAllocator::set_memory_fraction_fn( |
55 | std::function<void(double, int)> memory_fraction_fn) { |
56 | memory_fraction_fn_ = memory_fraction_fn; |
57 | } |
58 | |
59 | void CUDAPluggableAllocator::set_base_alloc_fn( |
60 | std::function<void*(void*, size_t*)> base_alloc_fn) { |
61 | base_alloc_fn_ = base_alloc_fn; |
62 | } |
63 | |
64 | void CUDAPluggableAllocator::set_record_stream_fn( |
65 | std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn) { |
66 | record_stream_fn_ = record_stream_fn; |
67 | } |
68 | |
69 | void CUDAPluggableAllocator::set_capture_begin_fn( |
70 | std::function<void(int, c10::cuda::CaptureId_t, c10::cuda::MempoolId_t)> |
71 | capture_begin_fn) { |
72 | capture_begin_fn_ = capture_begin_fn; |
73 | } |
74 | |
75 | void CUDAPluggableAllocator::set_capture_about_to_end_fn( |
76 | std::function<void(int, c10::cuda::CaptureId_t)> capture_about_to_end_fn) { |
77 | capture_about_to_end_fn_ = capture_about_to_end_fn; |
78 | } |
79 | |
80 | void CUDAPluggableAllocator::set_capture_ended_fn( |
81 | std::function<void(int, c10::cuda::CaptureId_t)> capture_ended_fn) { |
82 | capture_ended_fn_ = capture_ended_fn; |
83 | } |
84 | |
85 | void CUDAPluggableAllocator::set_capture_destroy_fn( |
86 | std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn) { |
87 | capture_destroy_fn_ = capture_destroy_fn; |
88 | } |
89 | |
90 | void* CUDAPluggableAllocator::malloc( |
91 | size_t size, |
92 | int device, |
93 | cudaStream_t stream) { |
94 | void* r = alloc_fn_(size, device, stream); |
95 | { |
96 | const std::lock_guard<std::mutex> lock(allocator_mutex_); |
97 | allocation_metadata_.emplace(r, _AllocationMetadata(size, device, stream)); |
98 | } |
99 | return r; |
100 | } |
101 | |
102 | c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) const { |
103 | int device; |
104 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
105 | cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device); |
106 | void* r = |
107 | const_cast<CUDAPluggableAllocator*>(this)->malloc(size, device, stream); |
108 | c10::DataPtr data_ptr = { |
109 | r, r, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)}; |
110 | return data_ptr; |
111 | } |
112 | |
113 | c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const { |
114 | return &custom_raw_deleter; |
115 | } |
116 | |
117 | void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) { |
118 | int device; |
119 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
120 | cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device); |
121 | return malloc(nbytes, device, stream); |
122 | } |
123 | |
124 | void* CUDAPluggableAllocator::raw_alloc_with_stream( |
125 | size_t nbytes, |
126 | cudaStream_t stream) { |
127 | int device; |
128 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
129 | return malloc(nbytes, device, stream); |
130 | } |
131 | |
132 | void CUDAPluggableAllocator::raw_delete(void* ptr) { |
133 | cudaStream_t stream; |
134 | int device_idx; |
135 | size_t size; |
136 | { |
137 | const std::lock_guard<std::mutex> lock(allocator_mutex_); |
138 | TORCH_CHECK( |
139 | allocation_metadata_.count(ptr), |
140 | "Trying to free a pointer not allocated here" ); |
141 | _AllocationMetadata& metadata = allocation_metadata_[ptr]; |
142 | size = metadata.size; |
143 | device_idx = metadata.device_idx; |
144 | stream = metadata.stream; |
145 | allocation_metadata_.erase(ptr); |
146 | } |
147 | free_fn_(ptr, size, device_idx, stream); |
148 | } |
149 | |
150 | void CUDAPluggableAllocator::init(int device_count) { |
151 | if (init_fn_) { |
152 | init_fn_(device_count); |
153 | } |
154 | initialized_ = true; |
155 | } |
156 | |
// Returns true once init() has run (see init() above).
bool CUDAPluggableAllocator::initialized() {
  return initialized_;
}
160 | |
161 | void CUDAPluggableAllocator::setMemoryFraction(double fraction, int device) { |
162 | if (memory_fraction_fn_) { |
163 | memory_fraction_fn_(fraction, device); |
164 | } |
165 | } |
166 | |
167 | void CUDAPluggableAllocator::emptyCache(void) { |
168 | if (reset_fn_) { |
169 | return reset_fn_(); |
170 | } |
171 | } |
172 | |
173 | void CUDAPluggableAllocator::cacheInfo(int dev_id, size_t* largestBlock) { |
174 | TORCH_CHECK( |
175 | false, |
176 | "CUDAPluggableAllocator does not yet support cacheInfo. " |
177 | "If you need it, please file an issue describing your use case." ); |
178 | } |
179 | |
180 | void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) { |
181 | if (base_alloc_fn_) { |
182 | return base_alloc_fn_(ptr, size); |
183 | } else { |
184 | return ptr; |
185 | } |
186 | } |
187 | |
188 | void CUDAPluggableAllocator::recordStream( |
189 | const c10::DataPtr& ptr, |
190 | streamType stream) { |
191 | if (record_stream_fn_) { |
192 | record_stream_fn_(ptr.get(), stream); |
193 | } |
194 | } |
195 | |
196 | c10::cuda::CUDACachingAllocator::DeviceStats CUDAPluggableAllocator:: |
197 | getDeviceStats(int device) { |
198 | TORCH_CHECK( |
199 | false, |
200 | "CUDAPluggableAllocator does not yet support getDeviceStats. " |
201 | "If you need it, please file an issue describing your use case." ); |
202 | } |
203 | |
204 | void CUDAPluggableAllocator::resetAccumulatedStats(int device) { |
205 | TORCH_CHECK( |
206 | false, |
207 | "CUDAPluggableAllocator does not yet support resetAccumulatedStats. " |
208 | "If you need it, please file an issue describing your use case." ); |
209 | } |
210 | |
211 | void CUDAPluggableAllocator::resetPeakStats(int device) { |
212 | TORCH_CHECK( |
213 | false, |
214 | "CUDAPluggableAllocator does not yet support resetPeakStats. " |
215 | "If you need it, please file an issue describing your use case." ); |
216 | } |
217 | |
218 | c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator:: |
219 | snapshot() { |
220 | TORCH_CHECK( |
221 | false, |
222 | "CUDAPluggableAllocator does not yet support snapshot. " |
223 | "If you need it, please file an issue describing your use case." ); |
224 | } |
225 | |
226 | std::shared_ptr<void> CUDAPluggableAllocator::getIpcDevPtr(std::string handle) { |
227 | TORCH_CHECK( |
228 | false, |
229 | "CUDAPluggableAllocator does not yet support getIpcDevPtr. " |
230 | "If you need it, please file an issue describing your use case." ); |
231 | } |
232 | |
233 | // CUDAGraph interactions |
234 | void CUDAPluggableAllocator::notifyCaptureBegin( |
235 | int device, |
236 | c10::cuda::CaptureId_t graph_id, |
237 | c10::cuda::MempoolId_t mempool_id) { |
238 | if (capture_begin_fn_) { |
239 | capture_begin_fn_(device, graph_id, mempool_id); |
240 | } |
241 | } |
242 | |
243 | void CUDAPluggableAllocator::notifyCaptureAboutToEnd( |
244 | int device, |
245 | c10::cuda::CaptureId_t graph_id) { |
246 | if (capture_about_to_end_fn_) { |
247 | capture_about_to_end_fn_(device, graph_id); |
248 | } |
249 | } |
250 | |
251 | void CUDAPluggableAllocator::notifyCaptureEnded( |
252 | int device, |
253 | c10::cuda::CaptureId_t graph_id) { |
254 | if (capture_ended_fn_) { |
255 | capture_ended_fn_(device, graph_id); |
256 | } |
257 | } |
258 | |
259 | void CUDAPluggableAllocator::notifyCaptureDestroy( |
260 | int device, |
261 | c10::cuda::MempoolId_t mempool_id) { |
262 | if (capture_destroy_fn_) { |
263 | capture_destroy_fn_(device, mempool_id); |
264 | } |
265 | } |
266 | |
267 | void CUDAPluggableAllocator::recordHistory( |
268 | bool enabled, |
269 | c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder, |
270 | size_t alloc_trace_max_entries, |
271 | bool alloc_trace_record_context) { |
272 | TORCH_CHECK( |
273 | false, |
274 | "CUDAPluggableAllocator does not yet support recordHistory. " |
275 | "If you need it, please file an issue describing your use case." ); |
276 | } |
277 | |
278 | void CUDAPluggableAllocator::attachOutOfMemoryObserver( |
279 | c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) { |
280 | TORCH_CHECK( |
281 | false, |
282 | "CUDAPluggableAllocator does not yet support attachOutOfMemoryObserver. " |
283 | "If you need it, please file an issue describing your use case." ); |
284 | } |
285 | |
// Pluggable allocators do not request pool-specific peer access.
bool CUDAPluggableAllocator::needsPoolSpecificPeerAccess() {
  return false;
}
289 | |
290 | std::string CUDAPluggableAllocator::name() { |
291 | return "pluggable" ; |
292 | } |
293 | |
// The allocator most recently installed via changeCurrentAllocator().
// Kept alive here so custom_raw_deleter() can route frees to it; null until
// changeCurrentAllocator() is called.
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
    current_custom_allocator;
296 | |
// Returns the custom allocator installed by changeCurrentAllocator(), or
// null if none has been installed yet.
std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
getCurrentAllocator() {
  return current_custom_allocator;
}
301 | |
302 | // TODO: add more functions in the argument |
303 | std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> |
304 | createCustomAllocator( |
305 | std::function<void*(size_t, int, cudaStream_t)> alloc_fn, |
306 | std::function<void(void*, size_t, int, cudaStream_t)> free_fn) { |
307 | std::shared_ptr<CUDAPluggableAllocator> allocator( |
308 | new CUDAPluggableAllocator(alloc_fn, free_fn)); |
309 | allocator->init(device_count); |
310 | return allocator; |
311 | } |
312 | |
313 | void changeCurrentAllocator( |
314 | std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator) { |
315 | TORCH_CHECK( |
316 | !c10::cuda::CUDACachingAllocator::allocator.load()->initialized(), |
317 | "Can't swap an already initialized allocator" ); |
318 | c10::cuda::CUDACachingAllocator::allocator.store(allocator.get()); |
319 | current_custom_allocator = allocator; |
320 | } |
321 | |
// Trampoline installed as the c10::DeleterFnPtr on DataPtrs created by
// CUDAPluggableAllocator::allocate(). Dereferences the global unchecked:
// assumes changeCurrentAllocator() ran before any such DataPtr is destroyed.
void custom_raw_deleter(void* ptr) {
  current_custom_allocator->raw_delete(ptr);
}
325 | |
326 | } // namespace CUDAPluggableAllocator |
327 | } // namespace cuda |
328 | } // namespace torch |
329 | |