#pragma once

#include <c10/core/Allocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAMacros.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Registry.h>

#include <array>
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

namespace c10 {

// The caching allocator will execute every registered callback when it is
// unable to find a free block inside the already allocated area.
class C10_CUDA_API FreeMemoryCallback {
 public:
  virtual ~FreeMemoryCallback() = default;
  virtual bool Execute() = 0;
};

C10_DECLARE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback);
#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \
  C10_REGISTER_CLASS(FreeCudaMemoryCallbacksRegistry, name, __VA_ARGS__);
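
// A minimal sketch (hypothetical names) of how a callback might be
// registered so the caching allocator can ask other components to release
// memory when it cannot satisfy a request from its own pools:
//
//   class MyFreeMemoryCallback : public c10::FreeMemoryCallback {
//    public:
//     bool Execute() override {
//       // Try to release memory held elsewhere; return true only if
//       // something was actually freed.
//       return false;
//     }
//   };
//   REGISTER_FREE_MEMORY_CALLBACK(my_free_memory_callback, MyFreeMemoryCallback);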

namespace cuda {

// TODO: Turn this into an honest-to-goodness class. I briefly attempted to do
// this, but it was a bit irritating to figure out how to also correctly
// apply the pimpl pattern so I didn't have to leak any internal implementation
// details in the header (CUDACachingAllocator could be made a pimpl, but
// you also need to appropriately define a class which is a subclass
// of Allocator. Not impossible, but it required a bit more surgery than
// I wanted to do at the time.)
//
// Why is this using a namespace rather than the old-style THCCachingAllocator_
// prefix? Mostly because it made the HIPify rules easier to write; _ is
// not counted as a word boundary, so you would otherwise have to list each
// of these functions individually.

namespace CUDACachingAllocator {

struct Stat {
  int64_t current = 0; // current value
  int64_t peak = 0; // maximum value observed since the last peak-stats reset
  int64_t allocated = 0; // cumulative amount added (monotonically increasing)
  int64_t freed = 0; // cumulative amount removed (monotonically increasing)
};

enum struct StatType : uint64_t {
  AGGREGATE = 0,
  SMALL_POOL = 1,
  LARGE_POOL = 2,
  NUM_TYPES = 3 // remember to update this whenever a new stat type is added
};

typedef std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)> StatArray;

// Struct containing memory allocator summary statistics for a device.
struct DeviceStats {
  // COUNT: allocations requested by client code
  StatArray allocation;
  // COUNT: number of allocated segments from cudaMalloc().
  StatArray segment;
  // COUNT: number of active memory blocks (allocated or used by stream)
  StatArray active;
  // COUNT: number of inactive, split memory blocks (unallocated but can't be
  // released via cudaFree)
  StatArray inactive_split;

  // SUM: bytes allocated by this memory allocator
  StatArray allocated_bytes;
  // SUM: bytes reserved by this memory allocator (both free and used)
  StatArray reserved_bytes;
  // SUM: bytes within active memory blocks
  StatArray active_bytes;
  // SUM: bytes within inactive, split memory blocks
  StatArray inactive_split_bytes;
  // SUM: bytes requested by client code
  StatArray requested_bytes;

  // COUNT: total number of failed calls to CUDA malloc necessitating cache
  // flushes.
  int64_t num_alloc_retries = 0;

  // COUNT: total number of OOMs (i.e. failed calls to CUDA after cache flush)
  int64_t num_ooms = 0;

  // COUNT: total number of oversize blocks allocated from pool
  Stat oversize_allocations;

  // COUNT: total number of oversize blocks requiring malloc
  Stat oversize_segments;

  // SIZE: maximum block size that is allowed to be split.
  int64_t max_split_size = 0;
};
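
// A usage sketch: the aggregate number of currently allocated bytes on a
// device can be read from getDeviceStats() (declared below), e.g.
//
//   const auto stats = c10::cuda::CUDACachingAllocator::getDeviceStats(0);
//   int64_t allocated_now =
//       stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current;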
98
99struct Context {
100 virtual ~Context() = default;
101};
102
103typedef std::shared_ptr<Context> (*CreateContextFn)(void);
104
105struct History {
106 void* addr;
107 size_t real_size; // unrounded, actually requested size
108 std::shared_ptr<Context> context; // per-watcher context
109};
110
// Struct containing info of an allocation block (i.e. a fractional part of a
// cudaMalloc).
struct BlockInfo {
  int64_t size = 0;
  int64_t requested_size = 0;
  int32_t gc_counter = 0;
  bool allocated = false;
  bool active = false;
  std::vector<History> history;
};

// Struct containing info of a memory segment (i.e. one contiguous cudaMalloc).
struct SegmentInfo {
  int64_t device = 0;
  int64_t address = 0;
  int64_t total_size = 0;
  int64_t requested_size = 0;
  int64_t allocated_size = 0;
  int64_t active_size = 0;
  cudaStream_t stream = 0;
  bool is_large = false;
  std::vector<BlockInfo> blocks;
};

struct TraceEntry {
  enum Action {
    ALLOC, // API call made to the caching allocator for new memory
    FREE_REQUESTED, // API call made to the caching allocator to free memory
    FREE_COMPLETED, // The allocator might have to delay a free because
                    // the block is still in use on another stream via
                    // record_stream. This event is generated when a free
                    // actually completes.
    SEGMENT_ALLOC, // a call to cudaMalloc to get more memory from the OS
    SEGMENT_FREE, // a call to cudaFree to return memory to the OS (e.g. to
                  // defragment or empty caches)
    SNAPSHOT, // a call to snapshot, used to correlate memory snapshots to trace
              // events
    OOM // the allocator threw an OutOfMemoryError (addr_ is the amount of free
        // bytes reported by cuda)
  };
  TraceEntry(
      Action action,
      int64_t addr,
      size_t size,
      cudaStream_t stream,
      std::shared_ptr<Context> context = nullptr)
      : action_(action),
        addr_(addr),
        context_(context),
        stream_(stream),
        size_(size) {}
  Action action_;
  int64_t addr_; // for OOM, this is the amount of free bytes reported by cuda
  std::shared_ptr<Context> context_;
  cudaStream_t stream_;
  int64_t size_;
};

struct SnapshotInfo {
  std::vector<SegmentInfo> segments;
  std::vector<std::vector<TraceEntry>> device_traces;
};

C10_CUDA_API void setAllocatorSettings(const std::string& env);

// Size pretty-printer
std::string format_size(uint64_t size);

using OutOfMemoryObserver = std::function<void(
    int64_t device,
    int64_t allocated,
    int64_t device_total,
    int64_t device_free)>;
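
// A minimal sketch (hypothetical body) of an observer that could be attached
// via attachOutOfMemoryObserver() below; it is invoked when the allocator
// reports an out-of-memory condition:
//
//   attachOutOfMemoryObserver([](int64_t device,
//                                int64_t allocated,
//                                int64_t device_total,
//                                int64_t device_free) {
//     // e.g. log the allocator state or dump a memory snapshot here
//   });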

class CUDAAllocator : public Allocator {
 public:
  virtual void* raw_alloc(size_t nbytes) = 0;
  virtual void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) = 0;
  virtual void raw_delete(void* ptr) = 0;
  virtual void init(int device_count) = 0;
  virtual bool initialized() = 0;
  virtual void setMemoryFraction(double fraction, int device) = 0;
  virtual void emptyCache() = 0;
  virtual void cacheInfo(int dev_id, size_t* largestBlock) = 0;
  virtual void* getBaseAllocation(void* ptr, size_t* size) = 0;
  virtual void recordStream(const DataPtr&, CUDAStream stream) = 0;
  virtual DeviceStats getDeviceStats(int device) = 0;
  virtual void resetAccumulatedStats(int device) = 0;
  virtual void resetPeakStats(int device) = 0;
  virtual SnapshotInfo snapshot() = 0;
  virtual void notifyCaptureBegin(
      int device,
      CaptureId_t graph_id,
      MempoolId_t mempool_id) = 0;
  virtual void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) = 0;
  virtual void notifyCaptureEnded(int device, CaptureId_t graph_id) = 0;
  virtual void notifyCaptureDestroy(int device, MempoolId_t mempool_id) = 0;
  virtual std::shared_ptr<void> getIpcDevPtr(std::string handle) = 0;
  virtual void recordHistory(
      bool enabled,
      CreateContextFn context_recorder,
      size_t alloc_trace_max_entries,
      bool alloc_trace_record_context) = 0;
  virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
  virtual bool needsPoolSpecificPeerAccess() = 0;
  virtual std::string name() = 0;
};

// Allocator object, statically initialized.
// See BackendInitializer in CUDACachingAllocator.cpp.
// Atomic loads on x86 are just normal loads (atomic stores are different),
// so reading this value is no different from loading a pointer.
C10_CUDA_API extern std::atomic<CUDAAllocator*> allocator;

inline CUDAAllocator* get() {
  return allocator.load();
}

// Called directly by clients.
inline void* raw_alloc(size_t nbytes) {
  return get()->raw_alloc(nbytes);
}

inline void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
  return get()->raw_alloc_with_stream(nbytes, stream);
}

inline void raw_delete(void* ptr) {
  return get()->raw_delete(ptr);
}
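
// A minimal usage sketch: memory obtained from raw_alloc() comes out of the
// caching allocator's pools and must be returned with raw_delete(), e.g.
//
//   void* buf = c10::cuda::CUDACachingAllocator::raw_alloc(1024);
//   // ... use buf on the current CUDA stream ...
//   c10::cuda::CUDACachingAllocator::raw_delete(buf);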

inline void init(int device_count) {
  return get()->init(device_count);
}

inline void setMemoryFraction(double fraction, int device) {
  return get()->setMemoryFraction(fraction, device);
}

inline void emptyCache() {
  return get()->emptyCache();
}

inline void cacheInfo(int dev_id, size_t* largestBlock) {
  return get()->cacheInfo(dev_id, largestBlock);
}

inline void* getBaseAllocation(void* ptr, size_t* size) {
  return get()->getBaseAllocation(ptr, size);
}

inline void recordStream(const DataPtr& dataPtr, CUDAStream stream) {
  return get()->recordStream(dataPtr, stream);
}

inline DeviceStats getDeviceStats(int device) {
  return get()->getDeviceStats(device);
}

inline void resetAccumulatedStats(int device) {
  return get()->resetAccumulatedStats(device);
}

inline void resetPeakStats(int device) {
  return get()->resetPeakStats(device);
}

inline SnapshotInfo snapshot() {
  return get()->snapshot();
}

// CUDAGraph interactions
inline void notifyCaptureBegin(
    int device,
    CaptureId_t graph_id,
    MempoolId_t mempool_id) {
  return get()->notifyCaptureBegin(device, graph_id, mempool_id);
}

inline void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
  return get()->notifyCaptureAboutToEnd(device, graph_id);
}

inline void recordHistory(
    bool enabled,
    CreateContextFn context_recorder,
    size_t alloc_trace_max_entries,
    bool alloc_trace_record_context) {
  return get()->recordHistory(
      enabled,
      context_recorder,
      alloc_trace_max_entries,
      alloc_trace_record_context);
}
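
// A minimal sketch of enabling allocation-history recording (argument values
// are illustrative, not prescribed): pass a null context recorder to skip
// per-allocation context capture, keep up to 10000 trace entries, and do not
// attach contexts to trace events.
//
//   recordHistory(
//       /*enabled=*/true,
//       /*context_recorder=*/nullptr,
//       /*alloc_trace_max_entries=*/10000,
//       /*alloc_trace_record_context=*/false);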

inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) {
  return get()->attachOutOfMemoryObserver(std::move(observer));
}

inline void notifyCaptureEnded(int device, CaptureId_t graph_id) {
  return get()->notifyCaptureEnded(device, graph_id);
}

inline void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {
  return get()->notifyCaptureDestroy(device, mempool_id);
}

// Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE
inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
  return get()->getIpcDevPtr(handle);
}

inline std::string name() {
  return get()->name();
}

} // namespace CUDACachingAllocator
} // namespace cuda
} // namespace c10