1#pragma once
2
3#include <c10/core/Allocator.h>
4#include <c10/cuda/CUDAGraphsC10Utils.h>
5#include <c10/cuda/CUDAMacros.h>
6#include <c10/cuda/CUDAStream.h>
7
8#include <c10/cuda/CUDACachingAllocator.h>
9
10#include <array>
11#include <mutex>
12
13namespace torch {
14
15namespace cuda {
16
17namespace CUDAPluggableAllocator {
18
19#if defined(TORCH_HIP_VERSION)
20using streamType = c10::hip::HIPStream;
21#else
22using streamType = c10::cuda::CUDAStream;
23#endif
24
25std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
26getCurrentAllocator();
27std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator>
28createCustomAllocator(
29 std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
30 std::function<void(void*, size_t, int, cudaStream_t)> free_fn);
31void changeCurrentAllocator(
32 std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator);
33
34struct _AllocationMetadata {
35 _AllocationMetadata();
36 _AllocationMetadata(size_t size, int device_idx, cudaStream_t stream);
37 size_t size;
38 int device_idx;
39 cudaStream_t stream;
40};
41
42struct CUDAPluggableAllocator
43 : public c10::cuda::CUDACachingAllocator::CUDAAllocator {
44 CUDAPluggableAllocator(
45 std::function<void*(size_t, int, cudaStream_t)> alloc_fn,
46 std::function<void(void*, size_t, int, cudaStream_t)> free_fn);
47
48 CUDAPluggableAllocator(CUDAPluggableAllocator& other);
49
50 void set_init_fn(std::function<void(int)> init_fn);
51
52 void set_reset_fn(std::function<void()> reset_fn);
53
54 void set_memory_fraction_fn(
55 std::function<void(double, int)> memory_fraction_fn);
56
57 void set_base_alloc_fn(std::function<void*(void*, size_t*)> base_alloc_fn);
58
59 void set_record_stream_fn(
60 std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn);
61
62 void set_capture_begin_fn(
63 std::function<void(int, c10::cuda::CaptureId_t, c10::cuda::MempoolId_t)>
64 capture_begin_fn);
65
66 void set_capture_about_to_end_fn(
67 std::function<void(int, c10::cuda::CaptureId_t)> capture_about_to_end_fn);
68
69 void set_capture_ended_fn(
70 std::function<void(int, c10::cuda::CaptureId_t)> capture_ended_fn);
71
72 void set_capture_destroy_fn(
73 std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn);
74
75 void* malloc(size_t size, int device, cudaStream_t stream);
76
77 c10::DataPtr allocate(size_t size) const override;
78 c10::DeleterFnPtr raw_deleter() const override;
79
80 virtual void* raw_alloc(size_t nbytes) override;
81 virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream)
82 override;
83 virtual void raw_delete(void* ptr) override;
84 virtual void init(int device_count) override;
85 virtual bool initialized() override;
86 virtual void setMemoryFraction(double fraction, int device) override;
87 virtual void emptyCache() override;
88 virtual void cacheInfo(int dev_id, size_t* largestBlock) override;
89 virtual void* getBaseAllocation(void* ptr, size_t* size) override;
90
91 virtual void recordStream(const c10::DataPtr&, streamType stream) override;
92
93 virtual c10::cuda::CUDACachingAllocator::DeviceStats getDeviceStats(
94 int device) override;
95 virtual void resetAccumulatedStats(int device) override;
96 virtual void resetPeakStats(int device) override;
97 virtual c10::cuda::CUDACachingAllocator::SnapshotInfo snapshot() override;
98 virtual void notifyCaptureBegin(
99 int device,
100 c10::cuda::CaptureId_t graph_id,
101 c10::cuda::MempoolId_t mempool_id) override;
102 virtual void notifyCaptureAboutToEnd(
103 int device,
104 c10::cuda::CaptureId_t graph_id) override;
105 virtual void notifyCaptureEnded(int device, c10::cuda::CaptureId_t graph_id)
106 override;
107 virtual void notifyCaptureDestroy(
108 int device,
109 c10::cuda::MempoolId_t mempool_id) override;
110 virtual std::shared_ptr<void> getIpcDevPtr(std::string handle) override;
111 virtual void recordHistory(
112 bool enabled,
113 c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder,
114 size_t alloc_trace_max_entries,
115 bool alloc_trace_record_context) override;
116 virtual void attachOutOfMemoryObserver(
117 c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override;
118 virtual bool needsPoolSpecificPeerAccess() override;
119 virtual std::string name() override;
120
121 protected:
122 std::function<void*(size_t, int, cudaStream_t)> alloc_fn_;
123 std::function<void(void*, size_t, int, cudaStream_t)> free_fn_;
124 std::function<void(int)> init_fn_;
125 std::function<void()> reset_fn_;
126 std::function<void(double, int)> memory_fraction_fn_;
127 std::function<void*(void*, size_t*)> base_alloc_fn_;
128 std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn_;
129 std::function<void(int, c10::cuda::CaptureId_t, c10::cuda::MempoolId_t)>
130 capture_begin_fn_;
131 std::function<void(int, c10::cuda::CaptureId_t)> capture_about_to_end_fn_;
132 std::function<void(int, c10::cuda::CaptureId_t)> capture_ended_fn_;
133 std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn_;
134 std::mutex allocator_mutex_;
135 // We do the bookeeping here in order to simplify custom allocators
136 std::unordered_map<void*, _AllocationMetadata> allocation_metadata_;
137
138 bool initialized_ = false;
139};
140} // namespace CUDAPluggableAllocator
141} // namespace cuda
142} // namespace torch
143