#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>

#include <unordered_set>
#include <vector>

namespace c10 {
namespace cuda {
namespace CUDACachingAllocator {
namespace CudaMallocAsync {

#if CUDA_VERSION >= 11040
// CUDA device allocator that uses cudaMallocAsync to implement
// the same interface as CUDACachingAllocator.cpp.

// Designed to be safe for CUDA graph capture.
// Interactions with CUDA graph capture are mediated by
// notifyCaptureBegin
// notifyCaptureAboutToEnd
// notifyCaptureEnded
// notifyCaptureDestroy

// Implementation details, not declared in CUDACachingAllocator.h
namespace {

// General helpers

struct UsageStream {
  cudaStream_t stream;
  int device;
  UsageStream() = default;
  UsageStream(cudaStream_t s, int d) : stream(s), device(d) {}
  UsageStream(const UsageStream& us) = default;
  UsageStream(UsageStream&& us) noexcept
      : stream(us.stream), device(us.device) {}
  UsageStream& operator=(UsageStream other) {
    stream = other.stream;
    device = other.device;
    return *this;
  }
};

bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
}

struct UsageStreamHash {
  size_t operator()(const UsageStream& us) const noexcept {
    return std::hash<void*>{}(us.stream) + size_t(us.device);
  }
};

struct PtrUsage {
  // recorded_streams holds side usage streams added by record_stream calls.
  // In other words, it does NOT include the original creation stream.
  ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
  UsageStream creation_stream;
  uint64_t size;
  bool captured;
  PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
};

int device_count = 0;
// these don't need to be c10::once_flags as in CUDAGeneratorImpl.cpp
// because they'll only be flipped by functions that have locked the mutex.
std::vector<bool> devs_initialized_flags;
std::vector<UsageStream> dummy_unifying_free_streams;

// Possible micro-optimization:
// Some accesses to ptr_info are read-only.
// We could let those be concurrent with a shared_mutex and
// have concurrent calls take a shared_lock.
// Keeping it simple with an ordinary mutex for now.
std::mutex general_mutex;

/**
 * Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * During CUDA graph capture, it's illegal to call cudaFreeAsync
 * on a pointer that came from a non-captured cudaMallocAsync.
 * Unfortunately, Python being what it is, it's impossible to be
 * sure no uncaptured tensor will ever have its destructor called
 * in a capturing region.
 * We avoid errors by
 * 1. remembering if allocated pointers were captured or uncaptured
 * 2. during capture, if we detect an attempt to free an uncaptured
 *    allocation on a capturing stream, don't free it immediately,
 *    just remember it and defer its cudaFreeAsync call to after
 *    the end of capture (specifically, to notifyCaptureEnded).
 */
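//
// Illustrative sequence for the policy above (a sketch, not an API example;
// "dev", "bytes", and "stream" are hypothetical placeholders):
//
//   void* p = nullptr;
//   mallocAsync(&p, dev, bytes, stream); // pre-capture: PtrUsage::captured == false
//   // ... graph capture begins (capture_underway becomes true) ...
//   freeAsync(p); // deferred: p is appended to
//                 // ungraphed_ptrs_defer_free_until_no_capture
//   // ... capture ends; notifyCaptureEnded() performs the deferred free_impl ...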

using PtrInfo = ska::flat_hash_map<void*, PtrUsage>;
PtrInfo ptr_info;
std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;

// These two help setMemoryFraction limit the amount of memory
// used by PyTorch in particular (as opposed to other libraries
// in the same process that might be sharing the same cudaMemPool_t).
std::vector<size_t> pytorch_used_bytes;
std::vector<size_t> pytorch_memory_limits;
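//
// For example (figures hypothetical): after setMemoryFraction(0.5, dev),
// pytorch_memory_limits[dev] == 0.5 * device_total, and mallocAsync refuses a
// request once pytorch_used_bytes[dev] + size would exceed that limit.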

// Graph-specific helpers

/**
 * Note [Avoid dangling free streams during CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * During capture, all stream dependencies must branch out from
 * the stream on which capture began and rejoin this initial stream
 * before capture ends.
 * The user rigs desired forking and joining with event waits.
 * But it's hard to be sure when tensor destructors get called relative
 * to the final joins.
 * For example, suppose a user
 *   forks work stream B from initial capture stream A
 *   creates a tensor T in B
 *   joins by syncing A with B
 *   ends capture.
 * All well and good, right? Maybe not: maybe T went out of scope
 * and its destructor got called AFTER the rejoin, leaving the graph with
 * "unjoined work": a dangling cudaFreeAsync node in stream B.
 * Ensuring that all tensor destructors for all side stream tensors
 * are called before side streams rejoin the main stream is
 * difficult. The user might have to add a bunch of explicit
 * "del"s at the right spots in code that was fine for ordinary
 * eager execution.
 * Fortunately, we can spare the user this burden:
 * during capture, we remember _all_ free streams,
 * and manually rejoin them with the capture stream during
 * notifyCaptureAboutToEnd.
 * This approach is heavy-handed, but hopefully capture only needs to
 * happen once, so we don't mind being heavy-handed.
 *
 * TODO: If, someday, we augment the graph bindings to support recapture
 * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#whole-graph-update
 * (eg, as a way to accommodate dynamic params) we should think more
 * carefully about the CPU overhead of remembering and rejoining
 * all free streams during capture. Maybe it's not a big deal.
 */
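//
// Timeline sketch of the scenario above (illustrative only):
//
//   stream A (capture): begin -- fork B ------------- wait on B -- end capture
//   stream B (side):             use T ... join A ... ~T -> cudaFreeAsync(T)  <-- dangling
//
// The remedy: notifyCaptureAboutToEnd records an event on every stream in
// capture_free_streams and makes the capture stream wait on it, so any such
// trailing cudaFreeAsync node is rejoined before capture ends.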
std::unordered_set<UsageStream, UsageStreamHash> capture_free_streams;
bool capture_underway = false;

// Implementation functions

// Assumes the caller holds general_mutex
inline void lazy_init_device(int device) {
  if (!devs_initialized_flags[device]) {
    CUDAGuard g(device);

    // See "Retaining memory in the pool" here:
    // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
    cudaMemPool_t mempool;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
    uint64_t threshold = UINT64_MAX;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolAttrReleaseThreshold, &threshold));

    // I think all these are on by default, but I want to enable them
    // explicitly to ensure awareness.
    int enable = 1;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseFollowEventDependencies, &enable));
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseAllowOpportunistic, &enable));
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseAllowInternalDependencies, &enable));

    // Grabs a stream from the current device to use as the "unifier" free
    // stream for allocations that end up used on multiple streams.
    const auto dufs = getStreamFromPool();
    dummy_unifying_free_streams[device] =
        UsageStream(dufs.stream(), dufs.device_index());

    pytorch_used_bytes[device] = 0;
    pytorch_memory_limits[device] = UINT64_MAX;

    devs_initialized_flags[device] = true;
  }
}

inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
  // CUDACachingAllocator.cpp uses raw cuda events, as do we.
  cudaEvent_t event;
  C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
  C10_CUDA_CHECK(cudaEventRecord(event, dependency));
  C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
  C10_CUDA_CHECK(cudaEventDestroy(event));
}

// Assumes the caller holds general_mutex
inline void free_impl(PtrInfo::iterator& it) {
  // Possible micro-optimization: If we did a value-copy here, we could move
  // ptr_info.erase(it) up here and drop the lock immediately.
  const auto& recorded_streams = it->second.recorded_streams;
  const auto& creation_stream = it->second.creation_stream;

  // If the usage stream is a null (default) stream,
  // cudaFreeAsync infers the device from the ambient context,
  // so we need to set the right ambient context.
  CUDAGuard g(creation_stream.device);

  if (recorded_streams.empty()) {
    // ptr was only used on one stream, which must have been
    // the original allocation stream.
    // Frees ptr in the original allocation stream.

    C10_CUDA_CHECK(cudaFreeAsync(it->first, creation_stream.stream));

    if (C10_UNLIKELY(capture_underway)) {
      // See Note [Avoid dangling free streams during CUDA graph capture]
      capture_free_streams.insert(creation_stream);
    }
  } else {
    // ptr was used on many streams. We don't know which was the most recent.
    // There could even have been multiple most recent usage streams acting
    // on different regions of the memory.
    // But cudaFreeAsync only accepts a single most recent usage stream.
    // We can still safely free ptr with a trick:
    // Use a dummy "unifying stream", sync the unifying stream with all of
    // ptr's usage streams, and pass the dummy stream to cudaFreeAsync.
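    //
    // Sketch of that event wiring (illustrative; "ev" = a disable-timing
    // event recorded by sync_raw):
    //
    //   creation_stream   --ev--\
    //   recorded_stream_0 --ev---+--> dummy_unifying_free_stream --> cudaFreeAsync(ptr)
    //   recorded_stream_1 --ev--/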

    // Retrieves the dummy "unifier" stream from the device
    // on which the pointer was originally allocated.
    auto dummy_unifying_free_stream =
        dummy_unifying_free_streams[creation_stream.device];
    TORCH_INTERNAL_ASSERT(
        dummy_unifying_free_stream.device == creation_stream.device);

    // we're already on creation_stream.device, no need to re-guard
    sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream);

    // The number of usage streams is typically small (low single digits)
    for (const auto& recorded_stream : recorded_streams) {
      // Logic here accommodates the chance some of the usage streams were on
      // other devices, which is possible if some usage kernels accessed the
      // memory via p2p.

      // cudaEventRecord requires that the input event and stream are on the
      // same device.
      CUDAGuard g_usage(recorded_stream.device);

      sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
    }

    // Frees ptr in the dummy "unifier" stream.
    C10_CUDA_CHECK(
        cudaFreeAsync(it->first, dummy_unifying_free_stream.stream));
    // At this point, unless dummy_unifying_free_stream happens to alias some
    // future user stream, the allocation is only available for "opportunistic"
    // reuse, ie, if the CPU sees dummy_unifying_free_stream has reached the
    // point that all events recorded on all usage streams have resolved from
    // the CPU's perspective. In theory, we could remove the need for the
    // driver to do this tracking by e.g. replacing
    //   cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event);
    // with
    //   cudaStreamWaitEvent(creation_stream.stream, event);
    // then cudaFreeAsyncing straight back into creation_stream.stream,
    // but this forces a potentially false dependency of creation_stream.stream
    // on all the recorded_streams.

    if (C10_UNLIKELY(capture_underway)) {
      // See Note [Avoid dangling free streams during CUDA graph capture]
      capture_free_streams.emplace(
          dummy_unifying_free_stream.stream,
          dummy_unifying_free_stream.device);
    }
  }

  pytorch_used_bytes[creation_stream.device] -= it->second.size;

  ptr_info.erase(it);
}

void freeAsync(void* ptr) {
  std::lock_guard<std::mutex> lk(general_mutex);

  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);
  auto it = ptr_info.find(ptr);
  TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

  if (C10_UNLIKELY(capture_underway)) {
    if (!it->second.captured) {
      TORCH_WARN_ONCE(
          "freeAsync() was called on an uncaptured allocation during graph capture "
          "(address = ",
          ptr,
          "). This may be benign, for example, a Python tensor in the capture "
          "might happen to shadow (use the same name as) an unrelated temporary "
          "tensor from somewhere before capture, pushing the earlier tensor "
          "out of scope. "
          "However, if the tensor we're freeing here IS used by the capture, "
          "freeing it is an error, and may cause illegal memory accesses or "
          "memory corruption during graph replay.");
      // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
      // Remembers the raw pointer, not the iterator.
      // This forces notifyCaptureEnded to do another lookup,
      // but avoids the risk the iterator might be invalidated
      // between now and then.
      ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
      return;
    }
  } else if (C10_UNLIKELY(it->second.captured)) {
    TORCH_WARN(
        "Attempting uncaptured free of a captured allocation with address ",
        ptr,
        "\nThis is technically allowed, but may indicate you are losing "
        "the last user-visible tensor through which the allocation can "
        "be accessed, so you'll have no way to view the data after "
        "future replays of the owning graph.");
  }

  free_impl(it);
}

// Symmetric with NativeCachingAllocator::malloc for now,
// although I don't think we absolutely need the symmetry.
void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
  TORCH_INTERNAL_ASSERT(
      0 <= device && device < device_count,
      "Invalid device index ",
      device,
      ": did you call init?");

  // If stream is a null (default) stream,
  // cudaMallocAsync infers the device from the ambient context,
  // so we need to set the right ambient context.
  CUDAGuard g(device);

  std::lock_guard<std::mutex> lk(general_mutex);

  lazy_init_device(device);

  // Defensively checks for preexisting CUDA error state.
  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);

  // TODO: Could we avoid calling cudaMallocAsync while holding general_mutex,
  // perhaps by letting lazy_init_device use separate once_flags or an internal
  // static initializer?
  if (pytorch_used_bytes[device] + size > pytorch_memory_limits[device]) {
    err = cudaErrorMemoryAllocation;
  } else {
    err = cudaMallocAsync(devPtr, size, stream);
  }

  if (err == cudaErrorMemoryAllocation) {
    // Clears CUDA's internal error state so the user, if desired, can catch
    // the OOM exception, free some stuff on the script side, and retry the
    // allocation. This aligns with the behavior of alloc_block in
    // CUDACachingAllocator.cpp.
    cudaGetLastError();
    size_t device_free;
    size_t device_total;
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    TORCH_CHECK_WITH(
        OutOfMemoryError,
        false,
        "Allocation on device ",
        device,
        " would exceed allowed memory. (out of memory)",
        "\nCurrently allocated     : ",
        format_size(pytorch_used_bytes[device]),
        "\nRequested               : ",
        format_size(size),
        "\nDevice limit            : ",
        format_size(device_total),
        "\nFree (according to CUDA): ",
        format_size(device_free),
        "\nPyTorch limit (set by user-supplied memory fraction)"
        "\n                        : ",
        format_size(pytorch_memory_limits[device]));
  } else {
    C10_CUDA_CHECK(err);
  }

  auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
  TORCH_INTERNAL_ASSERT(
      inserted.second,
      "address returned by cudaMallocAsync already exists "
      "in ptr_info");

  inserted.first->second.creation_stream = {stream, device};

  pytorch_used_bytes[device] += size;
}
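
// The cudaErrorMemoryAllocation path above surfaces as c10::OutOfMemoryError,
// which callers may catch to release cached references and retry. A minimal
// caller-side sketch (hypothetical, not part of this file's API surface):
//
//   void* p = nullptr;
//   try {
//     mallocAsync(&p, device, nbytes, stream);
//   } catch (const c10::OutOfMemoryError&) {
//     // drop unneeded references / free other memory, then retry once
//     mallocAsync(&p, device, nbytes, stream);
//   }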

} // anonymous namespace

void local_raw_delete(void* ptr);

// Same pattern as CUDACachingAllocator.cpp.
struct CudaMallocAsyncAllocator : public CUDAAllocator {
  DataPtr allocate(size_t size) const override {
    constexpr size_t one_exa_bytes = 1152921504606846976ULL;
    TORCH_CHECK_WITH(
        OutOfMemoryError,
        size < one_exa_bytes,
        "CUDA out of memory. Tried to allocate more than 1EB memory.");
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    void* r = nullptr;
    if (size != 0) {
      mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
    }
    return {r, r, &local_raw_delete, Device(DeviceType::CUDA, device)};
  }
  DeleterFnPtr raw_deleter() const override {
    return &local_raw_delete;
  }

  // This function should not issue any context-creating calls,
  // just set up for later calls to init per-device pools based
  // on the current device each later call sees.
  void init(int dev_count) override {
    static bool called = [](int dev_count) {
      // Are there external guarantees init will be called before
      // any of the allocator's other functions?
      // std::lock_guard<std::mutex> lk(general_mutex);
      device_count = dev_count;
      devs_initialized_flags.resize(dev_count, false);
      dummy_unifying_free_streams.resize(dev_count);
      pytorch_used_bytes.resize(dev_count);
      pytorch_memory_limits.resize(dev_count);
      return true;
    }(dev_count);
    (void)called;
  }

  bool initialized() override {
    return devs_initialized_flags.size() > 0;
  }

  static inline void assertValidDevice(int device) {
    TORCH_CHECK(
        0 <= device && device < device_count, "Invalid device argument.");
  }

  void setMemoryFraction(double fraction, int device) override {
    TORCH_INTERNAL_ASSERT(
        0 <= fraction && fraction <= 1,
444 | "invalid fraction:" , |
445 | fraction, |
446 | ". Please set within (0, 1)." ); |

    std::lock_guard<std::mutex> lk(general_mutex);
    assertValidDevice(device);
    CUDAGuard g(device);
    // Should setMemoryFraction be allowed to trigger a full device context and
    // pool-creating lazy_init_device, or should we simply assert this device
    // is already initialized, ie
    // TORCH_CHECK(devs_initialized_flags[device], ...)?
    lazy_init_device(device);

    size_t device_free;
    size_t device_total;
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    pytorch_memory_limits[device] =
        static_cast<uint64_t>(fraction * device_total);

    // Alternative: Instead of a manual hard limit, we could use
    // cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold,
    // &threshold); This is a soft hint: The driver allows the pool's reserved
    // memory to spike above threshold in regions of high cudaMallocAsync
    // demand, but opportunistically trims reserved memory back to threshold
    // when the memory in use is < threshold. I don't like this because it
    // introduces performance nondeterminism.
  }

  void emptyCache(void) override {
    std::lock_guard<std::mutex> lk(general_mutex);

    for (int dev = 0; dev < device_count; dev++) {
      if (devs_initialized_flags[dev]) {
        CUDAGuard g(dev);

        cudaMemPool_t mempool;
        C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, dev));
        C10_CUDA_CHECK(cudaDeviceSynchronize());
        C10_CUDA_CHECK(cudaMemPoolTrimTo(mempool, 0));
      }
    }
  }

  void cacheInfo(int device, size_t* maxWorkspaceGuess) override {
    // The only consumer of cacheInfo is getMaxWorkspaceSize in Conv_v7.cpp.
    // Afaict, the role of cacheInfo is to give getMaxWorkspaceSize a
    // reasonable maximum workspace size to use for an upcoming cudnnFind call.
    //
    // The native allocator's cacheInfo chooses to return the size of its
    // largest unused block (which is the largest allocation the native
    // allocator can service immediately and asynchronously without a
    // cudaMalloc).
    //
    // Here, we use a different heuristic: figure out the max usable workspace
    // size with a bit of educated trial and error. It's ok to be
    // perf-inefficient because cacheInfo is a prelude to cudnnFind.
    //
    // The algo cache then stores the best-performing algo with workspace <=
    // maxWorkspaceGuess. Later calls with the same param set hit in cache and
    // try to allocate the same workspace. If, in one of those future calls,
    // workspace allocation fails (ie because less ambient memory is
    // available), the bindings rerun cudnnFind, including calling cacheInfo
    // again beforehand to estimate a new (smaller) largest-available
    // workspace. Over a few such calls, the cache should settle to the algo
    // with a workspace size that's small enough to succeed every time (for
    // that param set).
    //
    // So the strategy here is to return a rough, largeish guess and let the
    // bindings retry to trim as needed over time.
    //
    // The only caveat is, even if a workspace is allocated without OOM errors
    // now and in future calls, it's hard to be sure those later error-free
    // cudaMallocAsyncs are fast and come straight from the pool (ie,
    // cudaMallocAsync didn't need to reserve more memory from the system).
    // Hopefully, after repeated workspace requests, the pool's reserved memory
    // also stabilizes to a point where they all come straight from the pool.
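    //
    // Concretely (figures hypothetical): if the initial guess is 1 GiB and
    // the trial cudaMallocAsync below keeps returning
    // cudaErrorMemoryAllocation, the loop halves the guess (512 MiB, 256 MiB,
    // ...) until an allocation succeeds, then reports that size as
    // *maxWorkspaceGuess.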
    std::lock_guard<std::mutex> lk(general_mutex);
    assertValidDevice(device);
    CUDAGuard g(device);
    lazy_init_device(device);

    size_t free_upper_bound;
    size_t device_total;
    C10_CUDA_CHECK(cudaMemGetInfo(&free_upper_bound, &device_total));
    TORCH_INTERNAL_ASSERT(
        free_upper_bound + pytorch_used_bytes[device] <= device_total);
    size_t guess = std::min(
        free_upper_bound,
        pytorch_memory_limits[device] - pytorch_used_bytes[device]);
    auto stream = c10::cuda::getCurrentCUDAStream();
    void* dummy;

    // Defensively checks for preexisting CUDA error state.
    auto err = cudaGetLastError();
    C10_CUDA_CHECK(err);

    while (true) {
      // Duplicates some logic from mallocAsync to work with the error state
      // directly instead of repeatedly catching an exception thrown by
      // mallocAsync.
      if (pytorch_used_bytes[device] + guess > pytorch_memory_limits[device]) {
        err = cudaErrorMemoryAllocation;
      } else {
        err = cudaMallocAsync(&dummy, guess, stream);
      }

      if (err == cudaSuccess) {
        cudaFreeAsync(dummy, stream);
        *maxWorkspaceGuess = guess;
        return;
      } else if (err == cudaErrorMemoryAllocation) {
        cudaGetLastError(); // clear CUDA error
        guess >>= 1; // quick and dirty: try half the size next iteration
      } else {
        C10_CUDA_CHECK(err);
      }
    }
  }

  void* getBaseAllocation(void* ptr, size_t* size) override {
    std::lock_guard<std::mutex> lk(general_mutex);

    auto it = ptr_info.find(ptr);
    TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

    if (size) {
      *size = it->second.size;
    }

    return ptr;
  }

  void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override {
    std::lock_guard<std::mutex> lk(general_mutex);
    auto ptr_val = ptr.get();
    // An empty tensor's storage().data() might be a null ptr. As there are no
    // blocks associated with those tensors, it is fine to do nothing here.
    if (!ptr_val) {
      return;
    }

    // The pointer should exist in the map already.
    auto it = ptr_info.find(ptr_val);
    TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

    UsageStream to_record{stream.stream(), stream.device_index()};
    if (to_record == it->second.creation_stream) {
      TORCH_WARN(
          "Called record_stream on tensor whose original creation stream "
          "matches the recorded stream. This is unnecessary and has no effect.");
    } else {
      it->second.recorded_streams.insert(to_record);
    }
  }

  std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support getIpcDevPtr. "
        "If you need it, please file an issue describing your use case.");
  }

  void recordHistory(
      bool enabled,
      CreateContextFn context_recorder,
      size_t alloc_trace_max_entries,
      bool alloc_trace_record_context) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support recordHistory. "
        "If you need it, please file an issue describing your use case.");
  }

  void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support attachOutOfMemoryObserver. "
        "If you need it, please file an issue describing your use case.");
  }

  // Collects stats for device.
  // If device hasn't been used yet, returns 0s without creating a context.
  DeviceStats getDeviceStats(int device) override {
    assertValidDevice(device);

    // Memory currently reserved by the mempool
    uint64_t reserved_mem_current = 0;
    // High-water mark of memory reserved by the mempool since last reset
    uint64_t reserved_mem_peak = 0;
    // Memory currently in use by the mempool
    uint64_t used_mem_current = 0;
    // High-water mark of memory in use by the mempool since last reset
    uint64_t used_mem_peak = 0;

    std::lock_guard<std::mutex> lk(general_mutex);

    if (devs_initialized_flags[device]) {
      CUDAGuard g(device);

      cudaMemPool_t mempool;
      C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrReservedMemCurrent, &reserved_mem_current));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrReservedMemHigh, &reserved_mem_peak));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrUsedMemCurrent, &used_mem_current));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrUsedMemHigh, &used_mem_peak));
    }

    // Many stat types are specific to the native allocator. We leave these
    // untouched. Their "struct Stat"s will contain zeroed values.
    DeviceStats stats;

    // In the native allocator:
    // allocated_bytes is the total bytes of blocks that have been malloc()ed
    // and not yet free()d.
    // active_bytes is the total bytes of blocks that have been malloc()ed but
    // not yet released back into a free pool. In other words, it includes all
    // allocated_bytes, as well as the bytes of "limbo state" blocks that have
    // already been free()ed but not yet free_block()ed back into a pool due to
    // outstanding stream_uses.
    //
    // Here, in the cudaMallocAsync allocator:
    // We simply ask the driver's opinion about active memory.
    // We don't bother distinguishing between allocated_bytes and active_bytes.
    stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        used_mem_current;
    stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        used_mem_peak;
    stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        used_mem_current;
    stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        used_mem_peak;
    stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        reserved_mem_current;
    stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        reserved_mem_peak;

    return stats;
  }

  void resetAccumulatedStats(int device) override {
    assertValidDevice(device);
    TORCH_WARN_ONCE(
        "For backend:cudaMallocAsync, resetAccumulatedStats has no effect.");
  }

  void resetPeakStats(int device) override {
    assertValidDevice(device);

    CUDAGuard g(device);
    cudaMemPool_t mempool;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
    // Using zero as the reset value is the method recommended by the CUDA
    // driver team. Vivek Kini says:
703 | // "Resetting to zero (which is the only valid value when setting |
704 | // ReservedMemHigh) resets it to ReservedMemCurrent inside the driver |
705 | // (same goes for UsedMemHigh/UsedMemCurrent)" |
706 | uint64_t zero = 0; |
707 | C10_CUDA_CHECK(cudaMemPoolSetAttribute( |
708 | mempool, cudaMemPoolAttrReservedMemHigh, &zero)); |
709 | C10_CUDA_CHECK( |
710 | cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero)); |
711 | } |
712 | |
713 | SnapshotInfo snapshot() override { |
714 | TORCH_CHECK( |
715 | false, |
716 | "Calling snapshot with backend:cudaMallocAsync is not meaningful. " |
717 | "(For backend:native, snapshot returns a detailed summary of all " |
718 | "blocks tracked by the allocator, but the cudaMallocAsync backend " |
719 | "does not track individual blocks.)" ); |
720 | // Alternative: TORCH_WARN |
721 | return {}; |
722 | } |
723 | |
724 | // CUDAGraph interactions |
725 | void notifyCaptureBegin( |
726 | int device, |
727 | CaptureId_t graph_id, |
728 | MempoolId_t mempool_id) override { |
729 | std::lock_guard<std::mutex> lk(general_mutex); |
730 | |
731 | TORCH_INTERNAL_ASSERT(capture_free_streams.empty()); |
732 | TORCH_CHECK( |
733 | !capture_underway, |
734 | "Only one capture at a time is allowed in a process." ) |
    capture_underway = true;
  }

  void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) override {
    assertValidDevice(device);

    std::lock_guard<std::mutex> lk(general_mutex);

    TORCH_CHECK(
        capture_underway,
        "CudaMallocAsync::notifyCaptureAboutToEnd called, "
        "but CudaMallocAsync::capture_underway is false.");

    auto capture_stream = cuda::getCurrentCUDAStream(device);

    // See Note [Avoid dangling free streams during CUDA graph capture]
    for (const auto& free_stream : capture_free_streams) {
      // cudaEventRecord requires that the input event and stream are on the
      // same device.
      CUDAGuard g(free_stream.device);

      // CUDACachingAllocator.cpp uses raw cuda events, as do we.
      cudaEvent_t event;
      C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
      C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
      C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
      C10_CUDA_CHECK(cudaEventDestroy(event));
    }

    capture_free_streams.clear();
  }

  void notifyCaptureEnded(int device, CaptureId_t graph_id) override {
    assertValidDevice(device);

    std::lock_guard<std::mutex> lk(general_mutex);

    TORCH_CHECK(
        capture_underway,
        "CudaMallocAsync::notifyCaptureEnded called, "
        "but CudaMallocAsync::capture_underway is false.");
    capture_underway = false;

    // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
    for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
      auto it = ptr_info.find(ptr);
      TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
      free_impl(it);
    }

    ungraphed_ptrs_defer_free_until_no_capture.clear();
  }

  void notifyCaptureDestroy(int device, MempoolId_t mempool_id) override {
    // Q: Do we need to do anything special here, like clear long-lived
    //    pointers created during the original capture (for example,
    //    tensors intended as the graph's I/O surface) that might still
    //    be resident in ptr_info?
    // A: I don't think so.
    //    Those allocations survived capture because the user held
    //    explicit tensor references to them.
    //    Those tensors' destructors will call freeAsync() on each pointer
    //    when the user is done with them.
    //    The freeAsync()s will probably incur
    //    TORCH_WARN("Attempting uncaptured free of a captured allocation..."
    //    but stale ptrs will not permanently leak into ptr_info.
  }

  void* raw_alloc(size_t nbytes) override {
    if (nbytes == 0) {
      return nullptr;
    }
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    void* r = nullptr;
    mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
    return r;
  }

  void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override {
    if (nbytes == 0) {
      return nullptr;
    }
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    void* r = nullptr;
    mallocAsync(&r, device, nbytes, stream);
    return r;
  }
  void raw_delete(void* ptr) override {
    freeAsync(ptr);
  }
  bool needsPoolSpecificPeerAccess() override {
    return true;
  }
  std::string name() override {
    return "cudaMallocAsync";
  }
};

CudaMallocAsyncAllocator device_allocator;

void local_raw_delete(void* ptr) {
  freeAsync(ptr);
}
CUDAAllocator* allocator() {
  return &device_allocator;
}
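
// Note: this backend is typically selected at runtime by setting the
// environment variable PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync; the
// dispatch from that setting to allocator() happens in CUDACachingAllocator.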

#else
CUDAAllocator* allocator() {
  TORCH_CHECK(false, "Cannot use cudaMallocAsyncAllocator with cuda < 11.4.");
  return nullptr;
}

#endif

} // namespace CudaMallocAsync
} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace c10