#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/UniqueVoidPtr.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>

#include <unordered_set>
#include <vector>

namespace c10 {
namespace cuda {
namespace CUDACachingAllocator {
namespace CudaMallocAsync {

#if CUDA_VERSION >= 11040
// CUDA device allocator that uses cudaMallocAsync to implement
// the same interface as CUDACachingAllocator.cpp.

// Designed to be safe for CUDA graph capture.
// Interactions with CUDA graph capture are mediated by
// notifyCaptureBegin
// notifyCaptureAboutToEnd
// notifyCaptureEnded
// notifyCaptureDestroy

// Implementation details, not declared in CUDACachingAllocator.h
namespace {

// General helpers

struct UsageStream {
  cudaStream_t stream;
  int device;
  UsageStream() = default;
  UsageStream(cudaStream_t s, int d) : stream(s), device(d) {}
  UsageStream(const UsageStream& us) = default;
  UsageStream(UsageStream&& us) noexcept
      : stream(us.stream), device(us.device) {}
  UsageStream& operator=(UsageStream other) {
    stream = other.stream;
    device = other.device;
    return *this;
  }
};

bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
}

struct UsageStreamHash {
  size_t operator()(const UsageStream& us) const noexcept {
    return std::hash<void*>{}(us.stream) + size_t(us.device);
  }
};

struct PtrUsage {
  // recorded_streams holds side usage streams added by record_stream calls.
  // In other words, it does NOT include the original creation stream.
  ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
  UsageStream creation_stream;
  uint64_t size;
  bool captured;
  PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
};

int device_count = 0;
// these don't need to be c10::once_flags as in CUDAGeneratorImpl.cpp
// because they'll only be flipped by functions that have locked the mutex.
std::vector<bool> devs_initialized_flags;
std::vector<UsageStream> dummy_unifying_free_streams;

// Possible micro-optimization:
// Some accesses to ptr_info are read-only.
// We could let those be concurrent with a shared_mutex and
// have concurrent calls take a shared_lock.
// Keeping it simple with an ordinary mutex for now.
std::mutex general_mutex;

/**
 * Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * During CUDA graph capture, it's illegal to call cudaFreeAsync
 * on a pointer that came from a non-captured cudaMallocAsync.
 * Unfortunately, Python being what it is, it's impossible to be
 * sure no uncaptured tensor will ever have its destructor called
 * in a capturing region.
 * We avoid errors by
 * 1. remembering if allocated pointers were captured or uncaptured
 * 2. during capture, if we detect an attempt to free an uncaptured
 *    allocation on a capturing stream, don't free it immediately,
 *    just remember it and defer its cudaFreeAsync call to after
 *    the end of capture (specifically, to notifyCaptureEnded).
 */
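
// Illustrative sketch (not from the original source) of the failure mode the
// Note above guards against, in raw CUDA runtime terms with hypothetical
// names:
//
//   void* uncaptured_ptr;
//   cudaMallocAsync(&uncaptured_ptr, nbytes, stream); // before capture begins
//   cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
//   // ... captured work ...
//   // If a tensor destructor ran here, a naive allocator would issue
//   //   cudaFreeAsync(uncaptured_ptr, stream);
//   // which is illegal while `stream` is capturing. Instead, freeAsync()
//   // below stashes the pointer in
//   // ungraphed_ptrs_defer_free_until_no_capture and free_impl() runs for it
//   // later, in notifyCaptureEnded().
//   cudaStreamEndCapture(stream, &graph);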

using PtrInfo = ska::flat_hash_map<void*, PtrUsage>;
PtrInfo ptr_info;
std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;

// These two help setMemoryFraction limit the amount of memory
// used by PyTorch in particular (as opposed to other libraries
// in the same process that might be sharing the same cudaMemPool_t).
std::vector<size_t> pytorch_used_bytes;
std::vector<size_t> pytorch_memory_limits;

// Graph-specific helpers

/**
 * Note [Avoid dangling free streams during CUDA graph capture]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * During capture, all stream dependencies must branch out from
 * the stream on which capture began and rejoin this initial stream
 * before capture ends.
 * The user rigs desired forking and joining with event waits.
 * But it's hard to be sure when tensor destructors get called relative
 * to the final joins.
 * For example, suppose a user
 *   forks work stream B from initial capture stream A
 *   creates a tensor T in B
 *   joins by syncing A with B
 *   ends capture.
 * All well and good, right? Maybe not: maybe T went out of scope
 * and its destructor got called AFTER the rejoin, leaving the graph with
 * "unjoined work": a dangling cudaFreeAsync node in stream B.
 * Ensuring that all tensor destructors for all side stream tensors
 * are called before side streams rejoin the main stream is
 * difficult. The user might have to add a bunch of explicit
 * "del"s at the right spots in code that was fine for ordinary
 * eager execution.
 * Fortunately, we can spare the user this burden:
 * during capture, we remember _all_ free streams,
 * and manually rejoin them with the capture stream during
 * notifyCaptureAboutToEnd.
 * This approach is heavy-handed, but hopefully capture only needs to
 * happen once, so we don't mind being heavy-handed.
 *
 * TODO: If, someday, we augment the graph bindings to support recapture
 * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#whole-graph-update
 * (eg, as a way to accommodate dynamic params) we should think more
 * carefully about the CPU overhead of remembering and rejoining
 * all free streams during capture. Maybe it's not a big deal.
 */
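
// Illustrative sketch (not from the original source) of the scenario the Note
// above describes, in raw CUDA terms with hypothetical stream/event names:
//
//   cudaStreamBeginCapture(A, cudaStreamCaptureModeGlobal);
//   cudaEventRecord(fork, A);  cudaStreamWaitEvent(B, fork);  // fork B from A
//   cudaMallocAsync(&t, nbytes, B);                           // tensor T on B
//   // ... kernels using t on B ...
//   cudaEventRecord(join, B);  cudaStreamWaitEvent(A, join);  // rejoin A
//   cudaFreeAsync(t, B);  // T's destructor ran after the join: dangling node!
//   cudaStreamEndCapture(A, &graph);
//
// notifyCaptureAboutToEnd() repairs this by recording an event on each stream
// in capture_free_streams and making the capture stream wait on it before
// capture ends.
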
std::unordered_set<UsageStream, UsageStreamHash> capture_free_streams;
bool capture_underway = false;

// Implementation functions

// Assumes the caller holds general_mutex
inline void lazy_init_device(int device) {
  if (!devs_initialized_flags[device]) {
    CUDAGuard g(device);

    // See "Retaining memory in the pool" here:
    // https://developer.nvidia.com/blog/using-cuda-stream-ordered-memory-allocator-part-1/
    cudaMemPool_t mempool;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
    uint64_t threshold = UINT64_MAX;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolAttrReleaseThreshold, &threshold));

    // I think all these are on by default, but I want to enable them
    // explicitly to ensure awareness.
    int enable = 1;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseFollowEventDependencies, &enable));
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseAllowOpportunistic, &enable));
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolReuseAllowInternalDependencies, &enable));

    // Grabs a stream from the current device to use as the "unifier" free
    // stream for allocations that end up used on multiple streams.
    const auto dufs = getStreamFromPool();
    dummy_unifying_free_streams[device] =
        UsageStream(dufs.stream(), dufs.device_index());

    pytorch_used_bytes[device] = 0;
    pytorch_memory_limits[device] = UINT64_MAX;

    devs_initialized_flags[device] = true;
  }
}

inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
  // CUDACachingAllocator.cpp uses raw cuda events, as do we.
  cudaEvent_t event;
  C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
  C10_CUDA_CHECK(cudaEventRecord(event, dependency));
  C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
  C10_CUDA_CHECK(cudaEventDestroy(event));
}

// Assumes the caller holds general_mutex
inline void free_impl(PtrInfo::iterator& it) {
  // Possible micro-optimization: If we did a value-copy here, we could move
  // ptr_info.erase(it) up here and drop the lock immediately.
  const auto& recorded_streams = it->second.recorded_streams;
  const auto& creation_stream = it->second.creation_stream;

  // If the usage stream is a null (default) stream,
  // cudaFreeAsync infers the device from the ambient context,
  // so we need to set the right ambient context.
  CUDAGuard g(creation_stream.device);

  if (recorded_streams.empty()) {
    // ptr was only used on one stream, which must have been
    // the original allocation stream.
    // Frees ptr in the original allocation stream.

    C10_CUDA_CHECK(cudaFreeAsync(it->first, creation_stream.stream));

    if (C10_UNLIKELY(capture_underway)) {
      // See Note [Avoid dangling free streams during CUDA graph capture]
      capture_free_streams.insert(creation_stream);
    }
  } else {
    // ptr was used on many streams. We don't know which was the most recent.
    // There could even have been multiple most recent usage streams acting
    // on different regions of the memory.
    // But cudaFreeAsync only accepts a single most recent usage stream.
    // We can still safely free ptr with a trick:
    // Use a dummy "unifying stream", sync the unifying stream with all of
    // ptr's usage streams, and pass the dummy stream to cudaFreeAsync.

    // Retrieves the dummy "unifier" stream from the device
    // on which the pointer was originally allocated.
    auto dummy_unifying_free_stream =
        dummy_unifying_free_streams[creation_stream.device];
    TORCH_INTERNAL_ASSERT(
        dummy_unifying_free_stream.device == creation_stream.device);

    // we're already on creation_stream.device, no need to re-guard
    sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream);

    // The number of usage streams is typically small (low single digits)
    for (const auto& recorded_stream : recorded_streams) {
      // Logic here accommodates the chance some of the usage streams were on
      // other devices, which is possible if some usage kernels accessed the
      // memory via p2p.

      // cudaEventRecord requires that the input event and stream are on the
      // same device.
      CUDAGuard g_usage(recorded_stream.device);

      sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
    }

    // Frees ptr in the dummy "unifier" stream.
    C10_CUDA_CHECK(cudaFreeAsync(it->first, dummy_unifying_free_stream.stream));
    // At this point, unless dummy_unifying_free_stream happens to alias some
    // future user stream, the allocation is only available for "opportunistic"
    // reuse, ie, if the CPU sees dummy_unifying_free_stream has reached the
    // point that all events recorded on all usage streams have resolved from
    // the CPU's perspective. In theory, we could remove the need for the driver
    // to do this tracking by e.g. replacing
    // cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event);
    // with
    // cudaStreamWaitEvent(creation_stream.stream, event);
    // then cudaFreeAsyncing straight back into creation_stream.stream,
    // but this forces a potentially false dependency of creation_stream.stream
    // on all the recorded_streams.

    if (C10_UNLIKELY(capture_underway)) {
      // See Note [Avoid dangling free streams during CUDA graph capture]
      capture_free_streams.emplace(
          dummy_unifying_free_stream.stream, dummy_unifying_free_stream.device);
    }
  }

  pytorch_used_bytes[creation_stream.device] -= it->second.size;

  ptr_info.erase(it);
}
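
// For reference, the multi-stream branch of free_impl() boils down to this
// call sequence (illustrative only; names are hypothetical):
//
//   cudaEventRecord(e0, creation_stream);   cudaStreamWaitEvent(unifier, e0);
//   cudaEventRecord(e1, recorded_stream_1); cudaStreamWaitEvent(unifier, e1);
//   // ... one record/wait pair per recorded stream ...
//   cudaFreeAsync(ptr, unifier);
//
// so the free is ordered after all previously-submitted work on every stream
// that ever used the allocation.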

void freeAsync(void* ptr) {
  std::lock_guard<std::mutex> lk(general_mutex);

  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);
  auto it = ptr_info.find(ptr);
  TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

  if (C10_UNLIKELY(capture_underway)) {
    if (!it->second.captured) {
      TORCH_WARN_ONCE(
          "freeAsync() was called on an uncaptured allocation during graph capture "
          "(address = ",
          ptr,
          "). This may be benign, for example, a Python tensor in the capture "
          "might happen to shadow (use the same name as) an unrelated temporary "
          "tensor from somewhere before capture, pushing the earlier tensor "
          "out of scope. "
          "However, if the tensor we're freeing here IS used by the capture, "
          "freeing it is an error, and may cause illegal memory accesses or "
          "memory corruption during graph replay.");
      // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
      // Remembers the raw pointer, not the iterator.
      // This forces notifyCaptureEnded to do another lookup,
      // but avoids the risk the iterator might be invalidated
      // between now and then.
      ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
      return;
    }
  } else if (C10_UNLIKELY(it->second.captured)) {
    TORCH_WARN(
        "Attempting uncaptured free of a captured allocation with address ",
        ptr,
        "\nThis is technically allowed, but may indicate you are losing "
        "the last user-visible tensor through which the allocation can "
        "be accessed, so you'll have no way to view the data after "
        "future replays of the owning graph.");
  }

  free_impl(it);
}

// Symmetric with NativeCachingAllocator::malloc for now,
// although I don't think we absolutely need the symmetry.
void mallocAsync(void** devPtr, int device, size_t size, cudaStream_t stream) {
  TORCH_INTERNAL_ASSERT(
      0 <= device && device < device_count,
      "Invalid device index ",
      device,
      ": did you call init?");

  // If stream is a null (default) stream,
  // cudaMallocAsync infers the device from the ambient context,
  // so we need to set the right ambient context.
  CUDAGuard g(device);

  std::lock_guard<std::mutex> lk(general_mutex);

  lazy_init_device(device);

  // Defensively checks for preexisting CUDA error state.
  auto err = cudaGetLastError();
  C10_CUDA_CHECK(err);

  // TODO: Could we avoid calling cudaMallocAsync while holding general_mutex,
  // perhaps by letting lazy_init_device use separate once_flags or an internal
  // static initializer?
  if (pytorch_used_bytes[device] + size > pytorch_memory_limits[device]) {
    err = cudaErrorMemoryAllocation;
  } else {
    err = cudaMallocAsync(devPtr, size, stream);
  }

  if (err == cudaErrorMemoryAllocation) {
    // Clears CUDA's internal error state so the user, if desired, can catch the
    // OOM exception, free some stuff on the script side, and retry the
    // allocation. This aligns with the behavior of alloc_block in
    // CUDACachingAllocator.cpp.
    cudaGetLastError();
    size_t device_free;
    size_t device_total;
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    TORCH_CHECK_WITH(
        OutOfMemoryError,
        false,
        "Allocation on device ",
        device,
        " would exceed allowed memory. (out of memory)",
        "\nCurrently allocated     : ",
        format_size(pytorch_used_bytes[device]),
        "\nRequested               : ",
        format_size(size),
        "\nDevice limit            : ",
        format_size(device_total),
        "\nFree (according to CUDA): ",
        format_size(device_free),
        "\nPyTorch limit (set by user-supplied memory fraction)"
        "\n                        : ",
        format_size(pytorch_memory_limits[device]));
  } else {
    C10_CUDA_CHECK(err);
  }

  auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
  TORCH_INTERNAL_ASSERT(
      inserted.second,
      "address returned by cudaMallocAsync already exists "
      "in ptr_info");

  inserted.first->second.creation_stream = {stream, device};

  pytorch_used_bytes[device] += size;
}

} // anonymous namespace

void local_raw_delete(void* ptr);

// Same pattern as CUDACachingAllocator.cpp.
struct CudaMallocAsyncAllocator : public CUDAAllocator {
  DataPtr allocate(size_t size) const override {
    constexpr size_t one_exa_bytes = 1152921504606846976ULL;
    TORCH_CHECK_WITH(
        OutOfMemoryError,
        size < one_exa_bytes,
        "CUDA out of memory. Tried to allocate more than 1EB memory.");
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    void* r = nullptr;
    if (size != 0) {
      mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device));
    }
    return {r, r, &local_raw_delete, Device(DeviceType::CUDA, device)};
  }
  DeleterFnPtr raw_deleter() const override {
    return &local_raw_delete;
  }

  // This function should not issue any context-creating calls,
  // just set up for later calls to init per-device pools based
  // on the current device each later call sees.
  void init(int dev_count) override {
    static bool called = [](int dev_count) {
      // Are there external guarantees init will be called before
      // any of the allocator's other functions?
      // std::lock_guard<std::mutex> lk(general_mutex);
      device_count = dev_count;
      devs_initialized_flags.resize(dev_count, false);
      dummy_unifying_free_streams.resize(dev_count);
      pytorch_used_bytes.resize(dev_count);
      pytorch_memory_limits.resize(dev_count);
      return true;
    }(dev_count);
    (void)called;
  }

  bool initialized() override {
    return devs_initialized_flags.size() > 0;
  }

  static inline void assertValidDevice(int device) {
    TORCH_CHECK(
        0 <= device && device < device_count, "Invalid device argument.");
  }

  void setMemoryFraction(double fraction, int device) override {
    TORCH_INTERNAL_ASSERT(
        0 <= fraction && fraction <= 1,
        "invalid fraction:",
        fraction,
        ". Please set within (0, 1).");

    std::lock_guard<std::mutex> lk(general_mutex);
    assertValidDevice(device);
    CUDAGuard g(device);
    // Should setMemoryFraction be allowed to trigger a full device context and
    // pool-creating lazy_init_device, or should we simply assert this device is
    // already initialized, ie
    // TORCH_CHECK(devs_initialized_flags[device], ...)?
    lazy_init_device(device);

    size_t device_free;
    size_t device_total;
    C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
    pytorch_memory_limits[device] =
        static_cast<uint64_t>(fraction * device_total);

    // Alternative: Instead of a manual hard limit, we could use
    // cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold,
    // &threshold); This is a soft hint: The driver allows the pool's reserved
    // memory to spike above threshold in regions of high cudaMallocAsync
    // demand, but opportunistically trims reserved memory back to threshold
    // when the memory in use is < threshold. I don't like this because it
    // introduces performance nondeterminism.
  }
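
  // Worked example (illustrative, not from the original source): with
  // fraction = 0.5 on a device whose cudaMemGetInfo total is 16 GiB,
  // pytorch_memory_limits[device] becomes 8 GiB. mallocAsync() then reports
  // an OOM (without even calling cudaMallocAsync) once
  // pytorch_used_bytes[device] + size would exceed that limit.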

  void emptyCache(void) override {
    std::lock_guard<std::mutex> lk(general_mutex);

    for (int dev = 0; dev < device_count; dev++) {
      if (devs_initialized_flags[dev]) {
        CUDAGuard g(dev);

        cudaMemPool_t mempool;
        cudaDeviceGetDefaultMemPool(&mempool, dev);
        cudaDeviceSynchronize();
        cudaMemPoolTrimTo(mempool, 0);
      }
    }
  }

  void cacheInfo(int device, size_t* maxWorkspaceGuess) override {
    // The only consumer of cacheInfo is getMaxWorkspaceSize in Conv_v7.cpp.
    // Afaict, the role of cacheInfo is to give getMaxWorkspaceSize a reasonable
    // maximum workspace size to use for an upcoming cudnnFind call.
    //
    // The native allocator's cacheInfo chooses to return the size of its
    // largest unused block (which is the largest allocation the native
    // allocator can service immediately and asynchronously without a
    // cudaMalloc).
    //
    // Here, we use a different heuristic: figure out the max usable workspace
    // size with a bit of educated trial and error. It's ok to be
    // perf-inefficient because cacheInfo is a prelude to cudnnFind.
    //
    // The algo cache then stores the best-performing algo with workspace <=
    // maxWorkspaceGuess. Later calls with the same param set hit in cache and
    // try to allocate the same workspace. If, in one of those future calls,
    // workspace allocation fails (ie because less ambient memory is available),
    // the bindings rerun cudnnFind, including calling cacheInfo again
    // beforehand to estimate a new (smaller) largest-available workspace. Over
    // a few such calls, the cache should settle to the algo with a workspace
    // size that's small enough to succeed every time (for that param set).
    //
    // So the strategy here is to return a rough, largeish guess and let the
    // bindings retry to trim as needed over time.
    //
    // The only caveat is, even if a workspace is allocated without OOM errors
    // now and in future calls, it's hard to be sure those later error-free
    // cudaMallocAsyncs are fast and come straight from the pool (ie,
    // cudaMallocAsync didn't need to reserve more memory from the system).
    // Hopefully, after repeated workspace requests, the pool's reserved memory
    // also stabilizes to a point where they all come straight from the pool.
    std::lock_guard<std::mutex> lk(general_mutex);
    assertValidDevice(device);
    CUDAGuard g(device);
    lazy_init_device(device);

    size_t free_upper_bound;
    size_t device_total;
    C10_CUDA_CHECK(cudaMemGetInfo(&free_upper_bound, &device_total));
    TORCH_INTERNAL_ASSERT(
        free_upper_bound + pytorch_used_bytes[device] <= device_total);
    size_t guess = std::min(
        free_upper_bound,
        pytorch_memory_limits[device] - pytorch_used_bytes[device]);
    auto stream = c10::cuda::getCurrentCUDAStream();
    void* dummy;

    // Defensively checks for preexisting CUDA error state.
    auto err = cudaGetLastError();
    C10_CUDA_CHECK(err);

    while (true) {
      // Duplicates some logic from mallocAsync to work with the error state
      // directly instead of repeatedly catching an exception thrown by
      // mallocAsync.
      if (pytorch_used_bytes[device] + guess > pytorch_memory_limits[device]) {
        err = cudaErrorMemoryAllocation;
      } else {
        err = cudaMallocAsync(&dummy, guess, stream);
      }

      if (err == cudaSuccess) {
        cudaFreeAsync(dummy, stream);
        *maxWorkspaceGuess = guess;
        return;
      } else if (err == cudaErrorMemoryAllocation) {
        cudaGetLastError(); // clear CUDA error
        guess >>= 1; // quick and dirty: try half the size next iteration
      } else {
        C10_CUDA_CHECK(err);
      }
    }
  }
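
  // Illustrative trace of the probe loop above (hypothetical numbers): if the
  // initial guess is 8 GiB but only ~3 GiB can currently be allocated, the
  // probes go 8 GiB (OOM) -> 4 GiB (OOM) -> 2 GiB (success), the 2 GiB probe
  // allocation is immediately freed back to the pool, and *maxWorkspaceGuess
  // is set to 2 GiB.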

  void* getBaseAllocation(void* ptr, size_t* size) override {
    std::lock_guard<std::mutex> lk(general_mutex);

    auto it = ptr_info.find(ptr);
    TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

    if (size) {
      *size = it->second.size;
    }

    return ptr;
  }

  void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override {
    std::lock_guard<std::mutex> lk(general_mutex);
    auto ptr_val = ptr.get();
    // An empty tensor's storage().data() might be a null ptr. As there are no
    // blocks associated with those tensors, it is fine to do nothing here.
    if (!ptr_val) {
      return;
    }

    // The pointer should exist in the map already.
    auto it = ptr_info.find(ptr_val);
    TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");

    UsageStream to_record{stream.stream(), stream.device_index()};
    if (to_record == it->second.creation_stream) {
      TORCH_WARN(
          "Called record_stream on tensor whose original creation stream "
          "matches the recorded stream. This is unnecessary and has no effect.");
    } else {
      it->second.recorded_streams.insert(to_record);
    }
  }

  std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support getIpcDevPtr. "
        "If you need it, please file an issue describing your use case.");
  }

  void recordHistory(
      bool enabled,
      CreateContextFn context_recorder,
      size_t alloc_trace_max_entries,
      bool alloc_trace_record_context) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support recordHistory. "
        "If you need it, please file an issue describing your use case.");
  }

  void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override {
    TORCH_CHECK(
        false,
        "cudaMallocAsync does not yet support attachOutOfMemoryObserver. "
        "If you need it, please file an issue describing your use case.");
  }

  // Collects stats for device.
  // If device hasn't been used yet, returns 0s without creating a context.
  DeviceStats getDeviceStats(int device) override {
    assertValidDevice(device);

    // Memory currently reserved by the mempool
    uint64_t reserved_mem_current = 0;
    // High-water mark of memory reserved by the mempool since last reset
    uint64_t reserved_mem_peak = 0;
    // Memory currently in use by the mempool
    uint64_t used_mem_current = 0;
    // High-water mark of memory in use by the mempool since last reset
    uint64_t used_mem_peak = 0;

    std::lock_guard<std::mutex> lk(general_mutex);

    if (devs_initialized_flags[device]) {
      CUDAGuard g(device);

      cudaMemPool_t mempool;
      C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrReservedMemCurrent, &reserved_mem_current));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrReservedMemHigh, &reserved_mem_peak));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrUsedMemCurrent, &used_mem_current));

      C10_CUDA_CHECK(cudaMemPoolGetAttribute(
          mempool, cudaMemPoolAttrUsedMemHigh, &used_mem_peak));
    }

    // Many stat types are specific to the native allocator. We leave these
    // untouched. Their "struct Stat"s will contain zeroed values.
    DeviceStats stats;

    // In the native allocator:
    // allocated_bytes is the total bytes of blocks that have been malloc()ed
    // and not yet free()d.
    // active_bytes is the total bytes of blocks that have been malloc()ed but
    // not yet released back into a free pool. In other words, it includes all
    // allocated_bytes, as well as the bytes of "limbo state" blocks that have
    // already been free()d but not yet free_block()ed back into a pool due to
    // outstanding stream_uses.
    //
    // Here, in the cudaMallocAsync allocator:
    // We simply ask the driver's opinion about active memory.
    // We don't bother distinguishing between allocated_bytes and active_bytes.
    stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        used_mem_current;
    stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        used_mem_peak;
    stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        used_mem_current;
    stats.active_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        used_mem_peak;
    stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current =
        reserved_mem_current;
    stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].peak =
        reserved_mem_peak;

    return stats;
  }

  void resetAccumulatedStats(int device) override {
    assertValidDevice(device);
    TORCH_WARN_ONCE(
        "For backend:cudaMallocAsync, resetAccumulatedStats has no effect.");
  }

  void resetPeakStats(int device) override {
    assertValidDevice(device);

    CUDAGuard g(device);
    cudaMemPool_t mempool;
    C10_CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, device));
    // Using zero as the reset value is the method recommended by Cuda driver
    // team. Vivek Kini says:
    //   "Resetting to zero (which is the only valid value when setting
    //    ReservedMemHigh) resets it to ReservedMemCurrent inside the driver
    //    (same goes for UsedMemHigh/UsedMemCurrent)"
    uint64_t zero = 0;
    C10_CUDA_CHECK(cudaMemPoolSetAttribute(
        mempool, cudaMemPoolAttrReservedMemHigh, &zero));
    C10_CUDA_CHECK(
        cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrUsedMemHigh, &zero));
  }

  SnapshotInfo snapshot() override {
    TORCH_CHECK(
        false,
        "Calling snapshot with backend:cudaMallocAsync is not meaningful. "
        "(For backend:native, snapshot returns a detailed summary of all "
        "blocks tracked by the allocator, but the cudaMallocAsync backend "
        "does not track individual blocks.)");
    // Alternative: TORCH_WARN
    return {};
  }

  // CUDAGraph interactions
  void notifyCaptureBegin(
      int device,
      CaptureId_t graph_id,
      MempoolId_t mempool_id) override {
    std::lock_guard<std::mutex> lk(general_mutex);

    TORCH_INTERNAL_ASSERT(capture_free_streams.empty());
    TORCH_CHECK(
        !capture_underway,
734 "Only one capture at a time is allowed in a process.")
    capture_underway = true;
  }

  void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) override {
    assertValidDevice(device);

    std::lock_guard<std::mutex> lk(general_mutex);

    TORCH_CHECK(
        capture_underway,
        "CudaMallocAsync::notifyCaptureAboutToEnd called, "
        "but CudaMallocAsync::capture_underway is false.");

    auto capture_stream = cuda::getCurrentCUDAStream(device);

    // See Note [Avoid dangling free streams during CUDA graph capture]
    for (const auto& free_stream : capture_free_streams) {
      // cudaEventRecord requires that the input event and stream are on the
      // same device.
      CUDAGuard g(free_stream.device);

      // CUDACachingAllocator.cpp uses raw cuda events, as do we.
      cudaEvent_t event;
      C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
      C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
      C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
      C10_CUDA_CHECK(cudaEventDestroy(event));
    }

    capture_free_streams.clear();
  }

  void notifyCaptureEnded(int device, CaptureId_t graph_id) override {
    assertValidDevice(device);

    std::lock_guard<std::mutex> lk(general_mutex);

    TORCH_CHECK(
        capture_underway,
        "CudaMallocAsync::notifyCaptureEnded called, "
        "but CudaMallocAsync::capture_underway is false.");
    capture_underway = false;

    // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
    for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
      auto it = ptr_info.find(ptr);
      TORCH_INTERNAL_ASSERT(it != ptr_info.end(), "ptr not found in ptr_info");
      free_impl(it);
    }

    ungraphed_ptrs_defer_free_until_no_capture.clear();
  }

  void notifyCaptureDestroy(int device, MempoolId_t mempool_id) override {
    // Q: Do we need to do anything special here, like clear long-lived
    //    pointers created during the original capture (for example,
    //    tensors intended as the graph's I/O surface) that might still
    //    be resident in ptr_info?
    // A: I don't think so.
    //    Those allocations survived capture because the user held
    //    explicit tensor references to them.
    //    Those tensors' destructors will call freeAsync() on each pointer
    //    when the user is done with them.
    //    The freeAsync()s will probably incur
    //    TORCH_WARN("Attempting uncaptured free of a captured allocation..."
    //    but stale ptrs will not permanently leak into ptr_info.
  }

  void* raw_alloc(size_t nbytes) override {
    if (nbytes == 0) {
      return nullptr;
    }
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    void* r = nullptr;
    mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
    return r;
  }

  void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override {
    if (nbytes == 0) {
      return nullptr;
    }
    int device;
    C10_CUDA_CHECK(cudaGetDevice(&device));
    void* r = nullptr;
    mallocAsync(&r, device, nbytes, stream);
    return r;
  }
  void raw_delete(void* ptr) override {
    freeAsync(ptr);
  }
  bool needsPoolSpecificPeerAccess() override {
    return true;
  }
  std::string name() override {
    return "cudaMallocAsync";
  }
};

CudaMallocAsyncAllocator device_allocator;

void local_raw_delete(void* ptr) {
  freeAsync(ptr);
}
CUDAAllocator* allocator() {
  return &device_allocator;
}
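
// Usage note (illustrative, not part of the original source): this backend is
// selected before the first CUDA allocation, typically via the environment,
// e.g.
//   PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync python script.py
// The string returned by name() above is how the active backend is reported.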

#else
CUDAAllocator* allocator() {
  TORCH_CHECK(false, "Cannot use cudaMallocAsyncAllocator with cuda < 11.4.");
  return nullptr;
}

#endif

} // namespace CudaMallocAsync
} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace c10
856