1 | #include <c10/cuda/CUDACachingAllocator.h> |
2 | |
3 | #include <c10/core/impl/GPUTrace.h> |
4 | #include <c10/cuda/CUDAException.h> |
5 | #include <c10/cuda/CUDAFunctions.h> |
6 | #include <c10/cuda/CUDAGuard.h> |
7 | #include <c10/util/UniqueVoidPtr.h> |
8 | #include <c10/util/flat_hash_map.h> |
9 | #include <c10/util/irange.h> |
10 | #include <c10/util/llvmMathExtras.h> |
11 | |
12 | #include <cuda_runtime_api.h> |
13 | #include <algorithm> |
14 | #include <bitset> |
15 | #include <cstdint> |
16 | #include <deque> |
17 | #include <iterator> |
18 | #include <map> |
19 | #include <memory> |
20 | #include <mutex> |
21 | #include <regex> |
22 | #include <set> |
23 | #include <utility> |
24 | #include <vector> |
25 | |
26 | namespace c10 { |
27 | |
28 | C10_DEFINE_REGISTRY(FreeCudaMemoryCallbacksRegistry, FreeMemoryCallback); |
29 | |
30 | namespace cuda { |
31 | namespace CUDACachingAllocator { |
32 | namespace Native { |
33 | |
34 | // |
35 | // Yet another caching allocator for CUDA device allocations. |
36 | // |
37 | // - Allocations are associated with a stream. Once freed, blocks can be |
38 | // re-allocated on the same stream, but not on any other stream. |
39 | // - The allocator attempts to find the smallest cached block that will fit the |
40 | // requested size. If the block is larger than the requested size, it may be |
41 | // split. If no block is found, the allocator will delegate to cudaMalloc. |
42 | // - If the cudaMalloc fails, the allocator will attempt to free one cached |
43 | // block of sufficient size that is not split and retry the allocation. |
44 | // If this also fails, the allocator will attempt to free all cached blocks |
45 | // that are not split and retry the allocation. |
46 | // - Large (>1MB) and small allocations are stored in separate pools. |
47 | // Small requests are packed into 2MB buffers. Large requests will use the |
48 | // smallest available free block or allocate a new block using cudaMalloc. |
49 | // - To reduce fragmentation, requests between 1MB and 10MB will allocate and |
50 | // split a 20MB block, if no free block of sufficient size is available. |
51 | // - To further reduce fragmentation, blocks >= 200MB are not allowed to be |
52 | // split. These oversize cached blocks will still satisfy requests within |
53 | // 20MB of the oversize cached block size. |
54 | // |
55 | // With this allocator, allocations and frees should logically be considered |
56 | // "usages" of the memory segment associated with streams, just like kernel |
57 | // launches. The programmer must insert the proper synchronization if memory |
58 | // segments are used from multiple streams. |
59 | // |
60 | // The library provides a recordStream() function to help insert the correct |
61 | // synchronization when allocations are used on multiple streams. This will |
62 | // ensure that the block is not reused before each recorded stream completes |
63 | // work. |
64 | // |
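//
// As an illustrative, non-authoritative sketch (the real call sites live in
// caller code, not in this file), the public API is typically used like this
// when one allocation is consumed on a second stream; `nbytes` and `s1` are
// hypothetical names:
//
//   c10::DataPtr ptr = c10::cuda::CUDACachingAllocator::get()->allocate(nbytes);
//   // ... enqueue work on side stream s1 that reads or writes ptr ...
//   c10::cuda::CUDACachingAllocator::recordStream(ptr, s1);
//   // When ptr is later freed, the allocator records an event on s1 and only
//   // recycles the block once that event has completed.
//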
65 | |
66 | /** |
67 | * Note [Interaction with CUDA graph capture] |
68 | * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
69 | * Graph capture performs a dry run of a region of execution, freezing all CUDA |
70 | * work (and virtual addresses used during that work) into a "graph." The graph |
71 | * may be "replayed" like a single giant kernel, with greatly reduced CPU |
72 | * overhead as well as modestly improved GPU performance. |
73 | * |
74 | * Because capture bakes in memory addresses, the memory used during capture |
75 | * must be available for the graph to use during replay. DeviceCachingAllocator |
76 | * assigns and frees memory eagerly and dynamically, so if we're not careful |
77 | * about managing graphs' memory, at replay time those memory addresses could be |
78 | * used by other tensors. |
79 | * |
80 | * To guarantee a graph's baked in addresses are safe to reuse in replay, |
81 | * DeviceAllocator satisfies allocations from a graph-private memory pool during |
82 | * capture, and doesn't begin cudaFreeing those addresses until the graph is |
83 | * destroyed. |
84 | * |
85 | * Within the private pool, allocations are freed and reassigned as usual during |
86 | * capture. Memory regions will be used in a consistent order during replay. So |
87 | * a private pool doesn't use memory more wastefully than the default pools |
88 | * during capture, but it does reserve its high-water mark of used memory away |
89 | * from the default pools as long as the capture(s) it served survive |
 * (regardless of whether those captures are idle or replaying).
91 | * |
92 | * CUDAGraph's requests for private pools are mediated by |
93 | * DeviceAllocator::notifyCaptureBegin, |
94 | * notifyCaptureAboutToEnd, |
95 | * notifyCaptureEnded, |
96 | * notifyCaptureDestroy. |
97 | */ |
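
// Rough lifecycle sketch of the per-device hooks defined later in this class
// (an outline under assumptions, not the exact call sequence, which lives in
// CUDAGraph's implementation):
//
//   notifyCaptureBegin(graph_id, mempool_id); // create/ref the private pool and
//                                             // remember graph_id -> mempool_id
//   // ... stream capture runs; allocations on the capturing stream are routed
//   //     to the private pool via get_pool() ...
//   notifyCaptureAboutToEnd(graph_id);        // capture finished; stop routing
//   // ... the graph may be replayed any number of times ...
//   notifyCaptureDestroy(mempool_id);         // unref the pool; once no graphs
//                                             // reference it, its unused blocks
//                                             // may be cudaFreed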
98 | |
99 | constexpr size_t kMinBlockSize = |
100 | 512; // all sizes are rounded to at least 512 bytes |
101 | constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB |
102 | constexpr size_t kSmallBuffer = |
103 | 2097152; // "small" allocations are packed in 2 MiB blocks |
104 | constexpr size_t kLargeBuffer = |
105 | 20971520; // "large" allocations may be packed in 20 MiB blocks |
106 | constexpr size_t kMinLargeAlloc = |
107 | 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer |
108 | constexpr size_t kRoundLarge = 2097152; // round up large allocations to 2 MiB |
109 | constexpr size_t kRoundUpPowerOfTwoIntervals = 16; |
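
// Worked example of the size classes above (a sketch; the request-to-segment
// mapping is implemented later in this file by round_size and
// get_allocation_size, so the exact numbers assume those defaults):
//   - a 100 KiB request rounds to a 512-byte multiple and, being <= kSmallSize,
//     is carved out of a 2 MiB (kSmallBuffer) segment;
//   - a 3 MiB request is > kSmallSize but < kMinLargeAlloc, so it is carved out
//     of a 20 MiB (kLargeBuffer) segment;
//   - a 30 MiB request exceeds kMinLargeAlloc and gets its own segment, rounded
//     up to a multiple of kRoundLarge (2 MiB), i.e. exactly 30 MiB here.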
110 | |
111 | namespace { |
112 | |
113 | using stream_set = ska::flat_hash_set<cuda::CUDAStream>; |
114 | |
115 | using StatTypes = std::array<bool, static_cast<size_t>(StatType::NUM_TYPES)>; |
116 | |
117 | void update_stat(Stat& stat, int64_t amount) { |
118 | stat.current += amount; |
119 | |
120 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
121 | stat.current >= 0, |
122 | "Negative tracked stat in CUDA allocator (likely logic error)." ); |
123 | |
124 | stat.peak = std::max(stat.current, stat.peak); |
125 | if (amount > 0) { |
126 | stat.allocated += amount; |
127 | } |
128 | if (amount < 0) { |
129 | stat.freed += -amount; |
130 | } |
131 | } |
132 | |
133 | void reset_accumulated_stat(Stat& stat) { |
134 | stat.allocated = 0; |
135 | stat.freed = 0; |
136 | } |
137 | |
138 | void reset_peak_stat(Stat& stat) { |
139 | stat.peak = stat.current; |
140 | } |
141 | |
142 | template <typename Func> |
143 | void for_each_selected_stat_type(const StatTypes& stat_types, Func f) { |
144 | for (const auto stat_type : c10::irange(stat_types.size())) { |
145 | if (stat_types[stat_type]) { |
146 | f(stat_type); |
147 | } |
148 | } |
149 | } |
150 | |
151 | void update_stat_array( |
152 | StatArray& stat_array, |
153 | int64_t amount, |
154 | const StatTypes& stat_types) { |
155 | for_each_selected_stat_type( |
156 | stat_types, [&stat_array, amount](size_t stat_type) { |
157 | update_stat(stat_array[stat_type], amount); |
158 | }); |
159 | } |
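
// Typical call-site sketch for the helpers above (mirrors how malloc()/free()
// below select stat buckets; the variable names here are illustrative only):
//
//   StatTypes stat_types = {false};
//   stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true;
//   stat_types[static_cast<size_t>(StatType::SMALL_POOL)] = true; // or LARGE_POOL
//   update_stat_array(stats.allocation, 1, stat_types); // bump both buckets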
160 | |
161 | struct Block; |
162 | struct PrivatePool; |
163 | typedef bool (*Comparison)(const Block*, const Block*); |
164 | |
165 | struct BlockPool { |
166 | BlockPool( |
167 | Comparison comparator, |
168 | bool small, |
169 | PrivatePool* private_pool = nullptr) |
170 | : blocks(comparator), is_small(small), owner_PrivatePool(private_pool) {} |
171 | std::set<Block*, Comparison> blocks; |
172 | const bool is_small; |
173 | PrivatePool* owner_PrivatePool; |
174 | }; |
175 | |
176 | struct HistoryChain { |
177 | History h; |
178 | std::unique_ptr<HistoryChain> next; // when blocks are merged we keep records |
179 | // of what used to be in the block |
180 | }; |
181 | |
182 | struct Block { |
183 | int device; // gpu |
184 | cudaStream_t stream; // allocation stream |
185 | stream_set stream_uses; // streams on which the block was used |
186 | size_t size; // block size in bytes |
187 | size_t requested_size; // memory originally requested |
188 | BlockPool* pool{nullptr}; // owning memory pool |
189 | void* ptr{nullptr}; // memory address |
190 | bool allocated{false}; // in-use flag |
191 | Block* prev{nullptr}; // prev block if split from a larger allocation |
192 | Block* next{nullptr}; // next block if split from a larger allocation |
193 | int event_count{0}; // number of outstanding CUDA events |
194 | int gc_count{0}; // counter for prioritizing older / less useful blocks for |
195 | // garbage collection |
196 | std::unique_ptr<HistoryChain> history; |
197 | HistoryChain* history_last{nullptr}; |
198 | |
199 | Block( |
200 | int device, |
201 | cudaStream_t stream, |
202 | size_t size, |
203 | BlockPool* pool, |
204 | void* ptr) |
205 | : device(device), |
206 | stream(stream), |
207 | stream_uses(), |
208 | size(size), |
209 | requested_size(0), |
210 | pool(pool), |
211 | ptr(ptr) {} |
212 | |
213 | // constructor for search key |
214 | Block(int device, cudaStream_t stream, size_t size) |
215 | : device(device), |
216 | stream(stream), |
217 | stream_uses(), |
218 | size(size), |
219 | requested_size(0) {} |
220 | |
221 | bool is_split() const { |
222 | return (prev != nullptr) || (next != nullptr); |
223 | } |
224 | }; |
225 | |
226 | static bool BlockComparator(const Block* a, const Block* b) { |
227 | if (a->stream != b->stream) { |
228 | return (uintptr_t)a->stream < (uintptr_t)b->stream; |
229 | } |
230 | if (a->size != b->size) { |
231 | return a->size < b->size; |
232 | } |
233 | return (uintptr_t)a->ptr < (uintptr_t)b->ptr; |
234 | } |
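
// Because blocks are ordered by (stream, size, ptr), a pool lookup built from
// the search-key constructor above can use std::set::lower_bound to find the
// smallest cached block on the right stream that is still large enough (see
// get_free_block further down). A minimal sketch, with names as used in this
// file:
//
//   Block search_key(device, stream, size);
//   auto it = pool.blocks.lower_bound(&search_key);
//   // *it, if present and on the same stream, is the best-fit cached block.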
235 | |
236 | struct AllocParams { |
237 | AllocParams( |
238 | int device, |
239 | size_t size, |
240 | cudaStream_t stream, |
241 | BlockPool* pool, |
242 | size_t alloc_size, |
243 | DeviceStats& stats) |
244 | : search_key(device, stream, size), |
245 | pool(pool), |
246 | alloc_size(alloc_size), |
247 | block(nullptr), |
248 | err(cudaSuccess) {} |
249 | |
250 | int device() const { |
251 | return search_key.device; |
252 | } |
253 | cudaStream_t stream() const { |
254 | return search_key.stream; |
255 | } |
256 | size_t size() const { |
257 | return search_key.size; |
258 | } |
259 | |
260 | Block search_key; |
261 | BlockPool* pool; |
262 | size_t alloc_size; |
263 | Block* block; |
264 | StatTypes stat_types = {false}; |
265 | cudaError_t err; |
266 | }; |
267 | |
268 | int trimHistoryBefore(Block* block, void* point) { |
269 | int n = 0; |
270 | while (block->history && block->history->h.addr < point) { |
271 | block->history = std::move(block->history->next); |
272 | ++n; |
273 | } |
274 | if (!block->history) { |
275 | block->history_last = nullptr; |
276 | } |
277 | return n; |
278 | } |
279 | |
280 | // Note: cudaEventCreate when concurrently invoked from multiple threads can be |
281 | // very expensive (at least on certain device/driver combinations). Thus, we a) |
282 | // serialize event creation at a per-device level, and b) pool the events to |
283 | // avoid constantly calling cudaEventCreate/cudaEventDestroy. This results in |
284 | // significant improvements in multithreaded workloads with high allocation |
285 | // rates. |
286 | class EventPool { |
287 | public: |
288 | using Event = std::unique_ptr<cudaEvent_t, std::function<void(cudaEvent_t*)>>; |
289 | // TODO: Explicit device count |
290 | EventPool() : pools_(at::cuda::device_count()) {} |
291 | |
292 | Event get(int device) { |
293 | TORCH_INTERNAL_ASSERT(0 <= device); |
294 | TORCH_INTERNAL_ASSERT(device < static_cast<int>(pools_.size())); |
295 | auto& pool = pools_[device]; |
296 | auto destructor = [&pool](cudaEvent_t* event) { |
297 | std::lock_guard<std::mutex> g(pool.mutex_); |
298 | pool.event_pool_.push_back(std::unique_ptr<cudaEvent_t>(event)); |
299 | }; |
300 | |
301 | // Try to acquire an event from the per-device pool. |
302 | { |
303 | std::lock_guard<std::mutex> g(pool.mutex_); |
304 | if (!pool.event_pool_.empty()) { |
305 | auto* event = pool.event_pool_.back().release(); |
306 | pool.event_pool_.pop_back(); |
307 | return Event(event, destructor); |
308 | } |
309 | } |
310 | // otherwise, allocate a new event that will be returned to the pool on |
311 | // destruction. |
312 | auto new_ptr = std::make_unique<cudaEvent_t>(); |
313 | C10_CUDA_CHECK( |
314 | cudaEventCreateWithFlags(new_ptr.get(), cudaEventDisableTiming)); |
315 | |
316 | return Event(new_ptr.release(), destructor); |
317 | } |
318 | |
319 | void empty_cache() { |
320 | for (auto& pool : pools_) { |
321 | std::lock_guard<std::mutex> g(pool.mutex_); |
322 | pool.event_pool_.clear(); |
323 | } |
324 | } |
325 | |
326 | private: |
327 | struct PerDevicePool { |
328 | alignas(64) std::mutex mutex_; |
329 | std::vector<std::unique_ptr<cudaEvent_t>> event_pool_; |
330 | }; |
331 | std::vector<PerDevicePool> pools_; |
332 | }; |
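
// Usage sketch for EventPool (matches how insert_events further down records
// end-of-life events; `event_pool`, `dev`, and `stream` are illustrative):
//
//   EventPool::Event event = event_pool.get(dev);    // reuse or create
//   C10_CUDA_CHECK(cudaEventRecord(*event, stream)); // mark pending work
//   // ... later, cudaEventQuery(*event) tells us the work has finished ...
//   // When the Event unique_ptr is destroyed, the custom deleter returns the
//   // underlying cudaEvent_t to the per-device pool instead of destroying it.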
333 | |
334 | // CUDA graphs helper |
335 | struct PrivatePool { |
336 | PrivatePool() |
337 | : use_count(1), |
338 | cudaMalloc_count(0), |
339 | large_blocks(BlockComparator, /*is_small=*/false, this), |
340 | small_blocks(BlockComparator, /*is_small=*/true, this) {} |
341 | PrivatePool(const PrivatePool&) = delete; |
342 | PrivatePool(PrivatePool&&) = delete; |
343 | PrivatePool& operator=(const PrivatePool&) = delete; |
344 | // Number of live graphs using this pool |
345 | int use_count; |
346 | // Number of unfreed cudaMallocs made for this pool. When use_count and |
347 | // cudaMalloc_count drop to zero, we can delete this PrivatePool from |
348 | // graph_pools. |
349 | int cudaMalloc_count; |
350 | // Instead of maintaining private BlockPools here, I could stuff all blocks |
351 | // (private or no) into the top-level large_blocks and small_blocks, and |
352 | // distinguish private blocks by adding a "pool id" check above the stream |
  // check in BlockComparator. BlockComparator is performance-critical though,
  // so I'd rather not add more logic to it.
355 | BlockPool large_blocks; |
356 | BlockPool small_blocks; |
357 | }; |
358 | |
359 | struct MempoolIdHash { |
360 | std::size_t operator()(const MempoolId_t& mempool_id) const noexcept { |
361 | return mempool_id.first != 0 ? mempool_id.first : mempool_id.second; |
362 | } |
363 | }; |
364 | |
365 | cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) { |
366 | // TODO: ideally we'd replace this with something like |
367 | // !defined(TORCH_HIP_VERSION) as CUDA <= 10 support was dropped and really |
368 | // this is only a workaround for TORCH_HIP_VERSION not being a sufficient guard |
369 | // to prevent ROCM build breakage. |
370 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
371 | if (at::cuda::currentStreamCaptureStatusMayInitCtx() == |
372 | at::cuda::CaptureStatus::None) { |
373 | #endif |
374 | return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size)); |
375 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
376 | } else { |
377 | // It's ok to capture cudaMallocs, as long as we never cudaFree those |
378 | // addresses before replay. |
379 | // Capturing cudaMalloc behaves nicely: it gives the graph new VA, |
380 | // but is ignored (won't leakily allocate new memory) in replays. |
381 | at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeRelaxed}; |
382 | return C10_CUDA_ERROR_HANDLED(cudaMalloc(p, size)); |
383 | } |
384 | #endif |
385 | } |
386 | |
387 | } // anonymous namespace |
388 | } // namespace Native |
389 | |
390 | // Environment config parser |
391 | // Defined here, rather than its own .cpp file, |
392 | // because parseArgs needs to know kLargeBuffer. |
393 | // Defined outside namespace Native because it's not Native-specific. |
394 | class CachingAllocatorConfig { |
395 | public: |
396 | static size_t max_split_size() { |
397 | return instance().m_max_split_size; |
398 | } |
399 | static double garbage_collection_threshold() { |
400 | return instance().m_garbage_collection_threshold; |
401 | } |
402 | |
  // This is used to round up the allocation size to the nearest power-of-2
  // division. See roundup_power2_next_division below for a fuller description.
  // As an example, to get 4 divisions between successive powers of 2, set the
  // env variable: PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:4
407 | static size_t roundup_power2_divisions(size_t size) { |
408 | size_t log_size = (63 - llvm::countLeadingZeros(size)); |
409 | |
410 | // Our intervals start at 1MB and end at 64GB |
411 | const size_t interval_start = |
412 | 63 - llvm::countLeadingZeros(static_cast<size_t>(1048576)); |
413 | const size_t interval_end = |
414 | 63 - llvm::countLeadingZeros(static_cast<size_t>(68719476736)); |
415 | TORCH_CHECK( |
416 | (interval_end - interval_start == Native::kRoundUpPowerOfTwoIntervals), |
417 | "kRoundUpPowerOfTwoIntervals mismatch" ); |
418 | |
419 | int index = static_cast<int>(log_size) - static_cast<int>(interval_start); |
420 | |
421 | index = std::max(0, index); |
422 | index = std::min( |
423 | index, static_cast<int>(Native::kRoundUpPowerOfTwoIntervals) - 1); |
424 | return instance().m_roundup_power2_divisions[index]; |
425 | } |
426 | |
427 | static CachingAllocatorConfig& instance() { |
428 | static CachingAllocatorConfig* s_instance = ([]() { |
429 | auto inst = new CachingAllocatorConfig(); |
430 | const char* env = getenv("PYTORCH_CUDA_ALLOC_CONF" ); |
431 | inst->parseArgs(env); |
432 | return inst; |
433 | })(); |
434 | return *s_instance; |
435 | } |
436 | |
437 | void parseArgs(const char* env); |
438 | |
439 | private: |
440 | CachingAllocatorConfig() |
441 | : m_max_split_size(std::numeric_limits<size_t>::max()), |
442 | m_garbage_collection_threshold(0) { |
443 | m_roundup_power2_divisions.assign(Native::kRoundUpPowerOfTwoIntervals, 0); |
444 | } |
445 | |
446 | void lexArgs(const char* env, std::vector<std::string>& config); |
447 | void consumeToken( |
448 | const std::vector<std::string>& config, |
449 | size_t i, |
450 | const char c); |
451 | size_t parseMaxSplitSize(const std::vector<std::string>& config, size_t i); |
452 | size_t parseGarbageCollectionThreshold( |
453 | const std::vector<std::string>& config, |
454 | size_t i); |
455 | size_t parseRoundUpPower2Divisions( |
456 | const std::vector<std::string>& config, |
457 | size_t i); |
458 | size_t parseAllocatorConfig( |
459 | const std::vector<std::string>& config, |
460 | size_t i, |
461 | bool& used_cudaMallocAsync); |
462 | |
463 | std::atomic<size_t> m_max_split_size; |
464 | std::vector<size_t> m_roundup_power2_divisions; |
465 | std::atomic<double> m_garbage_collection_threshold; |
466 | }; |
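
// Example configuration strings this parser accepts (a sketch based on the
// options handled in parseArgs below; the values are arbitrary):
//
//   PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128
//   PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.8,max_split_size_mb:512
//   PYTORCH_CUDA_ALLOC_CONF=roundup_power2_divisions:[256:1,512:2,1024:4,>:8]
//   PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync
//
// Options are comma-separated; each option and its value are separated by ':'.
// The roundup_power2_divisions value may be a single number or a '['...']'
// list of interval:division pairs, where '>' covers everything above the last
// listed interval.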
467 | |
468 | void CachingAllocatorConfig::lexArgs( |
469 | const char* env, |
470 | std::vector<std::string>& config) { |
471 | std::vector<char> buf; |
472 | |
473 | size_t env_length = strlen(env); |
474 | for (size_t i = 0; i < env_length; i++) { |
475 | if (env[i] == ',' || env[i] == ':' || env[i] == '[' || env[i] == ']') { |
476 | if (buf.size() != 0) { |
477 | config.emplace_back(buf.begin(), buf.end()); |
478 | buf.clear(); |
479 | } |
480 | config.emplace_back(1, env[i]); |
481 | } else if (env[i] != ' ') { |
482 | buf.emplace_back(static_cast<char>(env[i])); |
483 | } |
484 | } |
485 | if (!buf.empty()) { |
486 | config.emplace_back(buf.begin(), buf.end()); |
487 | } |
488 | } |
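
// lexArgs splits on ',', ':', '[' and ']' (each kept as its own token) and
// drops spaces. For example (illustrative input):
//
//   "max_split_size_mb:128, backend:native"
//     -> {"max_split_size_mb", ":", "128", ",", "backend", ":", "native"}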
489 | |
490 | void CachingAllocatorConfig::consumeToken( |
491 | const std::vector<std::string>& config, |
492 | size_t i, |
493 | const char c) { |
494 | TORCH_CHECK( |
495 | i < config.size() && config[i].compare(std::string(1, c)) == 0, |
496 | "Error parsing CachingAllocator settings, expected " , |
497 | c, |
498 | "" ); |
499 | } |
500 | |
501 | size_t CachingAllocatorConfig::parseMaxSplitSize( |
502 | const std::vector<std::string>& config, |
503 | size_t i) { |
504 | consumeToken(config, ++i, ':'); |
505 | if (++i < config.size()) { |
506 | size_t val1 = stoi(config[i]); |
507 | TORCH_CHECK( |
508 | val1 > Native::kLargeBuffer / (1024 * 1024), |
509 | "CachingAllocator option max_split_size_mb too small, must be > " , |
510 | Native::kLargeBuffer / (1024 * 1024), |
511 | "" ); |
512 | val1 = std::max(val1, Native::kLargeBuffer / (1024 * 1024)); |
513 | val1 = std::min(val1, (std::numeric_limits<size_t>::max() / (1024 * 1024))); |
514 | m_max_split_size = val1 * 1024 * 1024; |
515 | } else { |
516 | TORCH_CHECK(false, "Error, expecting max_split_size_mb value" , "" ); |
517 | } |
518 | return i; |
519 | } |
520 | |
521 | size_t CachingAllocatorConfig::parseGarbageCollectionThreshold( |
522 | const std::vector<std::string>& config, |
523 | size_t i) { |
524 | consumeToken(config, ++i, ':'); |
525 | if (++i < config.size()) { |
526 | double val1 = stod(config[i]); |
    TORCH_CHECK(
        val1 > 0,
        "garbage_collection_threshold too small, set it within (0.0, 1.0)", "");
    TORCH_CHECK(
        val1 < 1.0,
        "garbage_collection_threshold too big, set it within (0.0, 1.0)", "");
531 | m_garbage_collection_threshold = val1; |
532 | } else { |
533 | TORCH_CHECK( |
534 | false, "Error, expecting garbage_collection_threshold value" , "" ); |
535 | } |
536 | return i; |
537 | } |
538 | |
539 | size_t CachingAllocatorConfig::parseRoundUpPower2Divisions( |
540 | const std::vector<std::string>& config, |
541 | size_t i) { |
542 | consumeToken(config, ++i, ':'); |
543 | bool first_value = true; |
544 | |
545 | if (++i < config.size()) { |
546 | if (config[i].compare("[" ) == 0) { |
547 | size_t last_index = 0; |
548 | while (++i < config.size() && config[i].compare("]" ) != 0) { |
549 | std::string val1 = config[i]; |
550 | size_t val2 = 0; |
551 | |
552 | consumeToken(config, ++i, ':'); |
553 | if (++i < config.size()) { |
554 | val2 = stoi(config[i]); |
555 | } else { |
556 | TORCH_CHECK( |
557 | false, "Error parsing roundup_power2_divisions value" , "" ); |
558 | } |
        TORCH_CHECK(
            llvm::isPowerOf2_64(val2),
            "For roundups, the divisions have to be a power of 2 ",
            "");
563 | |
564 | if (val1.compare(">" ) == 0) { |
565 | std::fill( |
566 | std::next( |
567 | m_roundup_power2_divisions.begin(), |
568 | static_cast<std::vector<unsigned long>::difference_type>( |
569 | last_index)), |
570 | m_roundup_power2_divisions.end(), |
571 | val2); |
572 | } else { |
573 | size_t val1_long = stoul(val1); |
          TORCH_CHECK(
              llvm::isPowerOf2_64(val1_long),
              "For roundups, the intervals have to be a power of 2 ",
              "");
578 | |
579 | size_t index = 63 - llvm::countLeadingZeros(val1_long); |
580 | index = std::max((size_t)0, index); |
581 | index = std::min(index, m_roundup_power2_divisions.size() - 1); |
582 | |
583 | if (first_value) { |
584 | std::fill( |
585 | m_roundup_power2_divisions.begin(), |
586 | std::next( |
587 | m_roundup_power2_divisions.begin(), |
588 | static_cast<std::vector<unsigned long>::difference_type>( |
589 | index)), |
590 | val2); |
591 | first_value = false; |
592 | } |
593 | if (index < m_roundup_power2_divisions.size()) { |
594 | m_roundup_power2_divisions[index] = val2; |
595 | } |
596 | last_index = index; |
597 | } |
598 | |
        if (i + 1 < config.size() && config[i + 1].compare("]") != 0) {
600 | consumeToken(config, ++i, ','); |
601 | } |
602 | } |
603 | } else { // Keep this for backwards compatibility |
604 | size_t val1 = stoi(config[i]); |
      TORCH_CHECK(
          llvm::isPowerOf2_64(val1),
          "For roundups, the divisions have to be a power of 2 ",
          "");
609 | std::fill( |
610 | m_roundup_power2_divisions.begin(), |
611 | m_roundup_power2_divisions.end(), |
612 | val1); |
613 | } |
614 | } else { |
615 | TORCH_CHECK(false, "Error, expecting roundup_power2_divisions value" , "" ); |
616 | } |
617 | return i; |
618 | } |
619 | |
620 | size_t CachingAllocatorConfig::parseAllocatorConfig( |
621 | const std::vector<std::string>& config, |
622 | size_t i, |
623 | bool& used_cudaMallocAsync) { |
624 | consumeToken(config, ++i, ':'); |
625 | if (++i < config.size()) { |
626 | TORCH_CHECK( |
627 | ((config[i] == "native" ) || (config[i] == "cudaMallocAsync" )), |
628 | "Unknown allocator backend, " |
629 | "options are native and cudaMallocAsync" ); |
630 | used_cudaMallocAsync = (config[i] == "cudaMallocAsync" ); |
631 | if (used_cudaMallocAsync) { |
632 | #if CUDA_VERSION >= 11040 |
633 | int version; |
634 | C10_CUDA_CHECK(cudaDriverGetVersion(&version)); |
635 | TORCH_CHECK( |
636 | version >= 11040, |
637 | "backend:cudaMallocAsync requires CUDA runtime " |
638 | "11.4 or newer, but cudaDriverGetVersion returned " , |
639 | version); |
640 | #else |
641 | TORCH_CHECK( |
642 | false, |
643 | "backend:cudaMallocAsync requires PyTorch to be built with " |
644 | "CUDA 11.4 or newer, but CUDA_VERSION is " , |
645 | CUDA_VERSION); |
646 | #endif |
647 | } |
648 | TORCH_INTERNAL_ASSERT( |
649 | config[i] == get()->name(), |
650 | "Allocator backend parsed at runtime != " |
651 | "allocator backend parsed at load time" ); |
652 | } else { |
653 | TORCH_CHECK(false, "Error parsing backend value" , "" ); |
654 | } |
655 | return i; |
656 | } |
657 | |
658 | void CachingAllocatorConfig::parseArgs(const char* env) { |
659 | // If empty, set the default values |
660 | m_max_split_size = std::numeric_limits<size_t>::max(); |
661 | m_roundup_power2_divisions.assign(Native::kRoundUpPowerOfTwoIntervals, 0); |
662 | m_garbage_collection_threshold = 0; |
663 | bool used_cudaMallocAsync = false; |
664 | bool used_native_specific_option = false; |
665 | |
666 | if (env == nullptr) { |
667 | return; |
668 | } |
669 | |
670 | std::vector<std::string> config; |
671 | lexArgs(env, config); |
672 | |
673 | for (size_t i = 0; i < config.size(); i++) { |
674 | if (config[i].compare("max_split_size_mb" ) == 0) { |
675 | i = parseMaxSplitSize(config, i); |
676 | used_native_specific_option = true; |
677 | } else if (config[i].compare("garbage_collection_threshold" ) == 0) { |
678 | i = parseGarbageCollectionThreshold(config, i); |
679 | used_native_specific_option = true; |
680 | } else if (config[i].compare("roundup_power2_divisions" ) == 0) { |
681 | i = parseRoundUpPower2Divisions(config, i); |
682 | used_native_specific_option = true; |
683 | } else if (config[i].compare("backend" ) == 0) { |
684 | i = parseAllocatorConfig(config, i, used_cudaMallocAsync); |
685 | } else { |
686 | TORCH_CHECK(false, "Unrecognized CachingAllocator option: " , config[i]); |
687 | } |
688 | |
689 | if (i + 1 < config.size()) { |
690 | consumeToken(config, ++i, ','); |
691 | } |
692 | } |
693 | |
694 | if (used_cudaMallocAsync && used_native_specific_option) { |
    TORCH_WARN(
        "backend:cudaMallocAsync ignores max_split_size_mb, "
        "roundup_power2_divisions, and garbage_collection_threshold.");
698 | } |
699 | } |
700 | |
701 | namespace Native { |
702 | |
703 | class DeviceCachingAllocator { |
704 | private: |
705 | // lock around all operations |
706 | mutable std::recursive_mutex mutex; |
707 | |
708 | // device statistics |
709 | DeviceStats stats; |
710 | |
711 | // unallocated cached blocks larger than 1 MB |
712 | BlockPool large_blocks; |
713 | |
714 | // unallocated cached blocks 1 MB or smaller |
715 | BlockPool small_blocks; |
716 | |
717 | // allocated or in use by a stream. Holds all active allocations, |
718 | // whether they came from graph_pools or one of the BlockPools above. |
719 | ska::flat_hash_set<Block*> active_blocks; |
720 | |
721 | // captures_underway tracks if a capture might be underway on any stream. |
722 | // Most of the time it's zero, in which case malloc can avoid calling |
723 | // cudaStreamGetCaptureInfo in the hot path. |
724 | int captures_underway = 0; |
725 | // See free() for this thing's purpose |
726 | std::vector<Block*> needs_events_deferred_until_no_capture; |
727 | // outstanding cuda events |
728 | ska::flat_hash_map< |
729 | cuda::CUDAStream, |
730 | std::deque<std::pair<EventPool::Event, Block*>>> |
731 | cuda_events; |
732 | |
733 | // record used memory. |
734 | size_t total_allocated_memory = 0; |
735 | |
736 | size_t allowed_memory_maximum = 0; |
737 | |
738 | bool set_fraction = false; |
739 | |
740 | bool record_history = false; |
741 | std::atomic<CreateContextFn> context_recorder_; |
742 | size_t alloc_trace_next = 0; |
743 | bool alloc_trace_record_context_ = false; |
744 | size_t alloc_trace_max_entries_ = 1; |
745 | std::vector<TraceEntry>* |
      alloc_trace; // pointer because we need to intentionally leak this on
                   // deallocation: it can hold references to Python state
                   // which will already be destroyed when we are in exit
                   // handlers
749 | |
750 | // Members specific to CUDA graphs |
751 | |
752 | // Private pools for CUDA graphs |
753 | ska::flat_hash_map<MempoolId_t, std::unique_ptr<PrivatePool>, MempoolIdHash> |
754 | graph_pools; |
755 | // Pools no longer referenced by any graph. Their BlockPools are eligible for |
756 | // free_blocks. Can't be a vector or deque because we might erase entries in |
757 | // any order. Could be an std::list, but we don't care much, access and |
758 | // insert/erase are rare. |
759 | ska::flat_hash_map<MempoolId_t, PrivatePool*, MempoolIdHash> |
760 | graph_pools_freeable; |
761 | |
762 | // Maps a capturing stream to its assigned private pool, |
763 | // in case we want multiple captures to share the same pool |
764 | ska::flat_hash_map<CaptureId_t, MempoolId_t> capture_to_pool_map; |
765 | |
766 | // XXX - maybe we should generalize and have multiple events |
767 | std::vector<OutOfMemoryObserver> oom_observers_; |
768 | |
769 | public: |
770 | DeviceCachingAllocator() |
771 | : large_blocks(BlockComparator, /*is_small=*/false), |
772 | small_blocks(BlockComparator, /*is_small=*/true), |
773 | alloc_trace(new std::vector<TraceEntry>()) { |
774 | stats.max_split_size = CachingAllocatorConfig::max_split_size(); |
775 | context_recorder_.store(nullptr); |
776 | } |
777 | |
778 | void recordHistory( |
779 | bool enabled, |
780 | CreateContextFn context_recorder, |
781 | size_t alloc_trace_max_entries, |
782 | bool alloc_trace_record_context) { |
783 | std::unique_lock<std::recursive_mutex> lock(mutex); |
784 | record_history = enabled; |
785 | context_recorder_.store(context_recorder); |
786 | alloc_trace_max_entries_ = std::max(size_t(1), alloc_trace_max_entries); |
787 | alloc_trace_record_context_ = alloc_trace_record_context; |
788 | alloc_trace_next = 0; |
789 | alloc_trace->clear(); |
790 | } |
791 | |
792 | void attachOutOfMemoryObserver(OutOfMemoryObserver observer) { |
793 | oom_observers_.emplace_back(std::move(observer)); |
794 | } |
795 | |
796 | // All public methods (except the above) acquire the allocator mutex. |
797 | // Thus, do not call a public method from another public method. |
798 | |
799 | Block* malloc(int device, size_t orig_size, cudaStream_t stream) { |
800 | // done outside the lock because we don't know what locks the recorder needs |
801 | // to have... |
802 | CreateContextFn context_recorder = context_recorder_.load(); |
803 | std::shared_ptr<Context> context = |
804 | context_recorder ? context_recorder() : nullptr; |
805 | |
806 | std::unique_lock<std::recursive_mutex> lock(mutex); |
807 | |
808 | if (C10_LIKELY(captures_underway == 0)) { |
809 | // Processes end-of-life events for outstanding allocations used on |
810 | // multiple streams (checks if their GPU-side uses are complete and |
811 | // recycles their memory if so) |
812 | // |
813 | // Q. Why skip process_events if a capture might be underway? |
814 | // A. process_events involves cudaEventQueries, illegal during CUDA graph |
815 | // capture. |
816 | // Dumb simple solution: defer reclaiming these allocations until after |
817 | // capture. Cross-stream memory use is uncommon, so the deferral's |
818 | // effect on memory use during capture should be small. |
819 | process_events(); |
820 | } |
821 | size_t size = round_size(orig_size); |
822 | auto& pool = get_pool(size, stream); |
823 | const size_t alloc_size = get_allocation_size(size); |
824 | AllocParams params(device, size, stream, &pool, alloc_size, stats); |
825 | params.stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
826 | params.stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true; |
827 | |
828 | // First, try to get a block from the existing pool. |
829 | bool block_found = |
830 | // Search pool |
831 | get_free_block(params) |
832 | // Trigger callbacks and retry search |
833 | || (trigger_free_memory_callbacks(params) && get_free_block(params)); |
834 | |
835 | // Can't reuse an existing block; try to get a new one. |
836 | if (!block_found) { |
837 | // Do garbage collection if the flag is set. |
838 | if (C10_UNLIKELY( |
839 | set_fraction && |
840 | CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) { |
841 | garbage_collect_cached_blocks(); |
842 | } |
843 | // Attempt allocate |
844 | block_found = alloc_block(params, false) |
845 | // Free enough available cached blocks to satisfy alloc and retry |
846 | // alloc. |
847 | || (release_available_cached_blocks(params) && |
848 | alloc_block(params, false)) |
849 | // Free all non-split cached blocks and retry alloc. |
850 | || (C10_LIKELY(captures_underway == 0) && release_cached_blocks() && |
851 | alloc_block(params, true)); |
852 | if (record_history && block_found) { |
853 | record_trace( |
854 | TraceEntry::SEGMENT_ALLOC, |
855 | int64_t(params.block->ptr), |
856 | params.block->size, |
857 | params.stream(), |
858 | context); |
859 | } |
860 | } |
861 | |
862 | if (!block_found) { |
863 | // For any error code other than cudaErrorMemoryAllocation, |
864 | // alloc_block should have thrown an exception already. |
865 | TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation); |
866 | |
867 | size_t device_free; |
868 | size_t device_total; |
869 | C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); |
870 | std::string allowed_info; |
871 | |
872 | if (set_fraction) { |
873 | allowed_info = format_size(allowed_memory_maximum) + " allowed; " ; |
874 | } |
875 | |
876 | if (record_history) { |
877 | record_trace( |
878 | TraceEntry::OOM, |
879 | device_free, |
880 | params.size(), |
881 | params.stream(), |
882 | std::move(context)); |
883 | } |
884 | stats.num_ooms += 1; |
885 | |
886 | c10::reportOutOfMemoryToProfiler( |
887 | size, |
888 | stats.allocated_bytes[static_cast<int64_t>(StatType::AGGREGATE)] |
889 | .current, |
890 | stats.reserved_bytes[static_cast<int64_t>(StatType::AGGREGATE)] |
891 | .current, |
892 | c10::Device(c10::DeviceType::CUDA, static_cast<DeviceIndex>(device))); |
893 | for (const auto& obs : oom_observers_) { |
894 | obs(device, |
895 | alloc_size, |
896 | set_fraction ? allowed_memory_maximum : device_total, |
897 | device_free); |
898 | } |
899 | // "total capacity": total global memory on GPU |
      // "allowed": memory the program is allowed to use, set by the fraction.
901 | // "already allocated": memory allocated by the program using the |
902 | // caching allocator |
903 | // "free": free memory as reported by the CUDA API |
904 | // "cached": memory held by the allocator but not used by the program |
905 | // |
906 | // The "allocated" amount does not include memory allocated outside |
907 | // of the caching allocator, such as memory allocated by other programs |
908 | // or memory held by the driver. |
909 | // |
910 | // The sum of "allocated" + "free" + "cached" may be less than the |
911 | // total capacity due to memory held by the driver and usage by other |
912 | // programs. |
913 | // |
      // Note that at this point release_cached_blocks has already returned all
915 | // possible "cached" memory to the driver. The only remaining "cached" |
916 | // memory is split from a larger block that is partially in-use. |
917 | TORCH_CHECK_WITH( |
918 | OutOfMemoryError, |
919 | false, |
920 | "CUDA out of memory. Tried to allocate " , |
921 | format_size(alloc_size), |
922 | " (GPU " , |
923 | device, |
924 | "; " , |
925 | format_size(device_total), |
926 | " total capacity; " , |
927 | format_size( |
928 | stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)] |
929 | .current), |
930 | " already allocated; " , |
931 | format_size(device_free), |
932 | " free; " , |
933 | allowed_info, |
934 | format_size( |
935 | stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)] |
936 | .current), |
937 | " reserved in total by PyTorch)" , |
938 | " If reserved memory is >> allocated memory try setting max_split_size_mb to avoid" |
939 | " fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF" , |
940 | "" ); |
941 | } |
942 | |
943 | TORCH_INTERNAL_ASSERT( |
944 | params.err == cudaSuccess && params.block != nullptr && |
945 | params.block->ptr != nullptr); |
946 | Block* block = params.block; |
947 | Block* remaining = nullptr; |
948 | |
949 | const bool already_split = block->is_split(); |
950 | if (should_split(block, size)) { |
951 | remaining = block; |
952 | |
953 | block = new Block(device, stream, size, &pool, block->ptr); |
954 | block->prev = remaining->prev; |
955 | if (block->prev) { |
956 | block->prev->next = block; |
957 | } |
958 | block->next = remaining; |
959 | |
960 | remaining->prev = block; |
961 | remaining->ptr = static_cast<char*>(remaining->ptr) + size; |
962 | remaining->size -= size; |
963 | bool inserted = pool.blocks.insert(remaining).second; |
964 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); |
965 | |
966 | if (record_history) { |
967 | trimHistoryBefore(remaining, (char*)block->ptr + size); |
968 | } |
969 | |
970 | if (already_split) { |
971 | // An already-split inactive block is being shrunk by size bytes. |
972 | update_stat_array( |
973 | stats.inactive_split_bytes, |
974 | -static_cast<std::int64_t>(block->size), |
975 | params.stat_types); |
976 | } else { |
977 | // A new split inactive block is being created from a previously unsplit |
978 | // block, size remaining->size bytes. |
979 | for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { |
980 | update_stat( |
981 | stats.inactive_split_bytes[stat_type], |
982 | static_cast<std::int64_t>(remaining->size)); |
983 | update_stat(stats.inactive_split[stat_type], 1); |
984 | }); |
985 | } |
986 | |
987 | } else if (already_split) { |
988 | // An already-split block is becoming active |
989 | for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { |
990 | update_stat( |
991 | stats.inactive_split_bytes[stat_type], |
992 | -static_cast<std::int64_t>(block->size)); |
993 | update_stat(stats.inactive_split[stat_type], -1); |
994 | }); |
995 | } |
996 | |
997 | block->allocated = true; |
998 | block->requested_size = orig_size; |
999 | if (record_history) { |
1000 | trimHistoryBefore(block, (char*)block->ptr + size); |
1001 | block->history = std::make_unique<HistoryChain>(HistoryChain{ |
1002 | History{block->ptr, orig_size, std::move(context)}, |
1003 | std::move(block->history)}); |
1004 | if (!block->history_last) { |
1005 | block->history_last = block->history.get(); |
1006 | } |
1007 | record_trace( |
1008 | TraceEntry::ALLOC, |
1009 | int64_t(block->ptr), |
1010 | orig_size, |
1011 | block->stream, |
1012 | block->history->h.context); |
1013 | } |
1014 | |
1015 | bool inserted = active_blocks.insert(block).second; |
1016 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inserted); |
1017 | |
1018 | for_each_selected_stat_type(params.stat_types, [&](size_t stat_type) { |
1019 | update_stat(stats.allocation[stat_type], 1); |
1020 | update_stat( |
1021 | stats.allocated_bytes[stat_type], |
1022 | static_cast<std::int64_t>(block->size)); |
1023 | update_stat(stats.active[stat_type], 1); |
1024 | update_stat( |
1025 | stats.active_bytes[stat_type], |
1026 | static_cast<std::int64_t>(block->size)); |
1027 | update_stat( |
1028 | stats.requested_bytes[stat_type], |
1029 | static_cast<std::int64_t>(block->requested_size)); |
1030 | }); |
1031 | if (block->size >= CachingAllocatorConfig::max_split_size()) |
1032 | update_stat(stats.oversize_allocations, 1); |
1033 | |
1034 | c10::reportMemoryUsageToProfiler( |
1035 | block->ptr, |
1036 | block->size, |
1037 | stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, |
1038 | stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, |
1039 | c10::Device(c10::DeviceType::CUDA, device)); |
1040 | |
1041 | return block; |
1042 | } |
1043 | |
1044 | void free(Block* block) { |
1045 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1046 | |
1047 | block->allocated = false; |
1048 | |
    // The following logic might modify the underlying Block, changing its
    // size. We record the original values ahead of time for reporting.
1051 | auto orig_block_ptr = block->ptr; |
1052 | auto orig_block_size = block->size; |
1053 | |
1054 | StatTypes stat_types = {false}; |
1055 | stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
1056 | stat_types[static_cast<size_t>(get_stat_type_for_pool(*(block->pool)))] = |
1057 | true; |
1058 | for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
1059 | update_stat(stats.allocation[stat_type], -1); |
1060 | update_stat( |
1061 | stats.allocated_bytes[stat_type], |
1062 | -static_cast<std::int64_t>(block->size)); |
1063 | }); |
1064 | if (block->history) { |
1065 | record_trace( |
1066 | TraceEntry::FREE_REQUESTED, |
1067 | int64_t(block->ptr), |
1068 | block->history->h.real_size, |
1069 | block->stream, |
1070 | block->history->h.context); |
1071 | } |
1072 | if (block->size >= CachingAllocatorConfig::max_split_size()) |
1073 | update_stat(stats.oversize_allocations, -1); |
1074 | |
1075 | if (!block->stream_uses.empty()) { |
1076 | if (C10_UNLIKELY(captures_underway)) { |
1077 | // It's forbidden to cudaEventQuery an event recorded during CUDA graph |
1078 | // capture. We conservatively defer recording end-of-life events until |
1079 | // the next call to process_events() (which won't happen until no |
1080 | // captures are underway) |
1081 | needs_events_deferred_until_no_capture.push_back(block); |
1082 | } else { |
1083 | insert_events(block); |
1084 | } |
1085 | } else { |
1086 | free_block(block); |
1087 | } |
1088 | |
1089 | c10::reportMemoryUsageToProfiler( |
1090 | orig_block_ptr, |
1091 | -orig_block_size, |
1092 | stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, |
1093 | stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current, |
1094 | c10::Device(c10::DeviceType::CUDA, block->device)); |
1095 | } |
1096 | |
1097 | void* getBaseAllocation(Block* block, size_t* outSize) { |
1098 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1099 | while (block->prev) { |
1100 | block = block->prev; |
1101 | } |
1102 | void* basePtr = block->ptr; |
1103 | if (outSize) { |
1104 | size_t size = 0; |
1105 | while (block) { |
1106 | size += block->size; |
1107 | block = block->next; |
1108 | } |
1109 | *outSize = size; |
1110 | } |
1111 | return basePtr; |
1112 | } |
1113 | |
1114 | void recordStream(Block* block, cuda::CUDAStream stream) { |
1115 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1116 | if (stream.stream() == block->stream) { |
1117 | // ignore uses on the allocation stream, since those don't require any |
1118 | // special synchronization |
1119 | return; |
1120 | } |
1121 | block->stream_uses.insert(stream); |
1122 | } |
1123 | |
1124 | /** set memory fraction to limit maximum allocated memory **/ |
1125 | void setMemoryFraction(double fraction) { |
1126 | size_t device_free; |
1127 | size_t device_total; |
1128 | C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); |
1129 | allowed_memory_maximum = static_cast<size_t>(fraction * device_total); |
1130 | set_fraction = true; |
1131 | } |
1132 | |
1133 | /** returns cached blocks to the system allocator **/ |
1134 | void emptyCache() { |
1135 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1136 | release_cached_blocks(); |
1137 | } |
1138 | |
1139 | /** Retrieves size of largest unused block held by the memory cache **/ |
1140 | void cacheInfo(size_t* largest) { |
1141 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1142 | if (*largest == |
1143 | 0) { // make an initial guess if a zero *largest is passed in |
1144 | size_t tmp_bytes; |
1145 | C10_CUDA_CHECK(cudaMemGetInfo( |
1146 | largest, // Use free memory as an optimistic initial guess of *largest |
1147 | &tmp_bytes)); |
1148 | } |
1149 | cache_info_aux(large_blocks, largest); |
1150 | cache_info_aux(small_blocks, largest); |
1151 | for (const auto& gp : graph_pools) { |
1152 | cache_info_aux(gp.second->large_blocks, largest); |
1153 | cache_info_aux(gp.second->small_blocks, largest); |
1154 | } |
1155 | } |
1156 | |
1157 | /** Returns a copy of the memory allocator stats **/ |
1158 | DeviceStats getStats() { |
1159 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1160 | return stats; |
1161 | } |
1162 | |
1163 | /** Resets the historical accumulation stats for the device **/ |
1164 | void resetAccumulatedStats() { |
1165 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1166 | |
1167 | for (const auto statType : |
1168 | c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) { |
1169 | reset_accumulated_stat(stats.allocation[statType]); |
1170 | reset_accumulated_stat(stats.segment[statType]); |
1171 | reset_accumulated_stat(stats.active[statType]); |
1172 | reset_accumulated_stat(stats.inactive_split[statType]); |
1173 | reset_accumulated_stat(stats.allocated_bytes[statType]); |
1174 | reset_accumulated_stat(stats.reserved_bytes[statType]); |
1175 | reset_accumulated_stat(stats.active_bytes[statType]); |
1176 | reset_accumulated_stat(stats.inactive_split_bytes[statType]); |
1177 | reset_accumulated_stat(stats.requested_bytes[statType]); |
1178 | } |
1179 | |
1180 | stats.num_alloc_retries = 0; |
1181 | stats.num_ooms = 0; |
1182 | reset_accumulated_stat(stats.oversize_allocations); |
1183 | reset_accumulated_stat(stats.oversize_segments); |
1184 | } |
1185 | |
1186 | /** Resets the historical peak stats for the device **/ |
1187 | void resetPeakStats() { |
1188 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1189 | |
1190 | for (const auto statType : |
1191 | c10::irange(static_cast<size_t>(StatType::NUM_TYPES))) { |
1192 | reset_peak_stat(stats.allocation[statType]); |
1193 | reset_peak_stat(stats.segment[statType]); |
1194 | reset_peak_stat(stats.active[statType]); |
1195 | reset_peak_stat(stats.inactive_split[statType]); |
1196 | reset_peak_stat(stats.allocated_bytes[statType]); |
1197 | reset_peak_stat(stats.reserved_bytes[statType]); |
1198 | reset_peak_stat(stats.active_bytes[statType]); |
1199 | reset_peak_stat(stats.inactive_split_bytes[statType]); |
1200 | reset_peak_stat(stats.requested_bytes[statType]); |
1201 | } |
1202 | reset_peak_stat(stats.oversize_allocations); |
1203 | reset_peak_stat(stats.oversize_segments); |
1204 | } |
1205 | |
1206 | /** Dump a complete snapshot of the memory held by the allocator. Potentially |
1207 | * VERY expensive. **/ |
1208 | std::vector<SegmentInfo> snapshot() { |
1209 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1210 | |
1211 | size_t total_active = 0; |
1212 | std::vector<SegmentInfo> result; |
1213 | const auto all_blocks = get_all_blocks(); |
1214 | for (const Block* const head_block : all_blocks) { |
1215 | if (head_block->prev != nullptr) { |
1216 | continue; |
1217 | } |
1218 | result.emplace_back(); |
1219 | SegmentInfo& segment_info = result.back(); |
1220 | segment_info.device = head_block->device; |
1221 | segment_info.address = reinterpret_cast<int64_t>(head_block->ptr); |
1222 | segment_info.stream = head_block->stream; |
1223 | segment_info.is_large = (!head_block->pool->is_small); |
1224 | |
1225 | const Block* block = head_block; |
1226 | while (block != nullptr) { |
1227 | segment_info.blocks.emplace_back(); |
1228 | BlockInfo& block_info = segment_info.blocks.back(); |
1229 | |
1230 | block_info.size = block->size; |
1231 | block_info.requested_size = block->requested_size; |
1232 | block_info.allocated = block->allocated; |
1233 | block_info.active = block->allocated || (block->event_count > 0) || |
1234 | !block->stream_uses.empty(); |
1235 | |
1236 | segment_info.total_size += block_info.size; |
1237 | if (block_info.allocated) { |
1238 | segment_info.allocated_size += block_info.size; |
1239 | } |
1240 | if (block_info.active) { |
1241 | segment_info.active_size += block_info.size; |
1242 | segment_info.requested_size += block_info.requested_size; |
1243 | } |
1244 | HistoryChain* h = block->history.get(); |
1245 | while (h) { |
1246 | block_info.history.push_back(h->h); |
1247 | h = h->next.get(); |
1248 | } |
1249 | block = block->next; |
1250 | } |
1251 | total_active += segment_info.active_size; |
1252 | } |
1253 | |
1254 | std::sort( |
1255 | result.begin(), |
1256 | result.end(), |
1257 | [](const SegmentInfo& a, const SegmentInfo& b) { |
1258 | return a.address < b.address; |
1259 | }); |
1260 | |
1261 | if (record_history) { |
1262 | record_trace(TraceEntry::SNAPSHOT, 0, total_active, 0, nullptr); |
1263 | } |
1264 | return result; |
1265 | } |
1266 | |
1267 | std::vector<TraceEntry> trace() { |
1268 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1269 | std::vector<TraceEntry> result; |
1270 | result.reserve(alloc_trace->size()); |
1271 | result.insert( |
1272 | result.end(), |
1273 | alloc_trace->begin() + alloc_trace_next, |
1274 | alloc_trace->end()); |
1275 | result.insert( |
1276 | result.end(), |
1277 | alloc_trace->begin(), |
1278 | alloc_trace->begin() + alloc_trace_next); |
1279 | return result; |
1280 | } |
1281 | |
  // This function takes the size and number-of-divisions arguments and rounds
  // the size argument up to the nearest power-of-2 division.
  // For example, if we need to round up 1200 and the number of divisions is 4,
  // the size 1200 lies between 1024 and 2048; with 4 divisions between them,
  // the values are 1024, 1280, 1536, and 1792. So the function will return
  // 1280 as the nearest power-of-2 division ceiling.
1288 | static size_t roundup_power2_next_division(size_t size, size_t divisions) { |
1289 | if (C10_UNLIKELY(size <= 4 || divisions <= 1)) { |
1290 | return size; |
1291 | } |
1292 | if (llvm::isPowerOf2_64(size)) { |
1293 | return size; |
1294 | } |
1295 | |
    // divide the space between these two powers of 2 into equal divisions
    // If the resulting division size is zero, return the power-of-2 ceiling.
    size_t power2_floor = llvm::PowerOf2Floor(size);
    size_t power2_division =
        power2_floor >> (63 - llvm::countLeadingZeros(divisions));
    if (C10_UNLIKELY(power2_division == 0)) {
      return (power2_floor << 1);
    }
    size_t round_size_floor = size & (~(power2_division - 1));
    return (round_size_floor == size) ? size
                                      : round_size_floor + power2_division;
1307 | } |
1308 | |
1309 | static size_t round_size(size_t size) { |
1310 | if (size < kMinBlockSize) { |
1311 | return kMinBlockSize; |
1312 | } else { |
1313 | auto divisions = CachingAllocatorConfig::roundup_power2_divisions(size); |
1314 | if (divisions > 0 && size > (kMinBlockSize * divisions)) { |
1315 | return roundup_power2_next_division(size, divisions); |
1316 | } else { |
1317 | return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); |
1318 | } |
1319 | } |
1320 | } |
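
  // Worked example for round_size (a sketch; numbers assume the default
  // config unless stated otherwise):
  //   - size = 1000: no divisions configured, so round up to a multiple of
  //     kMinBlockSize (512) -> 1024;
  //   - size = 1'500'000 with roundup_power2_divisions = 4: size exceeds
  //     kMinBlockSize * 4, so round up to the next of the 4 divisions between
  //     2^20 and 2^21 -> 1'572'864 (1.5 MiB).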
1321 | |
1322 | // See Note [Interaction with CUDA graph capture] |
1323 | |
1324 | // Called by CUDAGraph::capture_begin |
1325 | void notifyCaptureBegin(CaptureId_t graph_id, MempoolId_t mempool_id) { |
1326 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1327 | captures_underway++; |
1328 | auto it = graph_pools.find(mempool_id); |
1329 | if (it == graph_pools.end()) { |
1330 | // mempool_id does not reference an existing pool. Make a new pool for |
1331 | // this capture. |
1332 | graph_pools.emplace(mempool_id, std::make_unique<PrivatePool>()); |
1333 | } else { |
1334 | // mempool_id references an existing pool, which the current capture will |
1335 | // share. Check this pool is live (at least one other capture already |
1336 | // references it). |
1337 | TORCH_INTERNAL_ASSERT(it->second->use_count > 0); |
1338 | it->second->use_count++; |
1339 | } |
1340 | // Maps this graph_id to mempool_id and makes sure this graph_id wasn't |
1341 | // somehow assigned a mempool_id already. Keeps essential effect (insert) |
1342 | // out of macro. |
1343 | bool inserted = capture_to_pool_map.insert({graph_id, mempool_id}).second; |
1344 | TORCH_INTERNAL_ASSERT(inserted); |
1345 | } |
1346 | |
1347 | // Called by CUDAGraph::capture_end |
1348 | void notifyCaptureAboutToEnd(CaptureId_t graph_id) { |
1349 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1350 | captures_underway--; |
1351 | auto it = capture_to_pool_map.find(graph_id); |
1352 | TORCH_INTERNAL_ASSERT(it != capture_to_pool_map.end()); |
1353 | capture_to_pool_map.erase(it); |
1354 | } |
1355 | |
1356 | // Called by CUDAGraph::reset |
1357 | void notifyCaptureDestroy(MempoolId_t mempool_id) { |
1358 | std::lock_guard<std::recursive_mutex> lock(mutex); |
1359 | // The instantiated cudaGraphExec_t has been destroyed. We can't blindly |
1360 | // delete and cudaFree the mempool its capture used, because |
1361 | // 1. other graph(s) might share the same pool |
1362 | // 2. the user might still hold references to output tensors allocated |
1363 | // during capture. |
1364 | // To handle 1 and 2, we track the number of graphs using this particular |
1365 | // mempool. When the count reaches 0, we tell free_cached_blocks it may now |
1366 | // cudaFree blocks from this graph's pool when it discovers they're unused |
1367 | // (unsplit). |
1368 | auto it = graph_pools.find(mempool_id); |
1369 | TORCH_INTERNAL_ASSERT(it != graph_pools.end()); |
1370 | auto uc = --(it->second->use_count); |
1371 | TORCH_INTERNAL_ASSERT(uc >= 0); |
1372 | if (uc == 0) { |
1373 | // Allows free_cached_blocks to begin cudaFreeing this pool's memory, |
1374 | // and makes sure this pool wasn't somehow made freeable already. |
1375 | bool inserted = |
1376 | graph_pools_freeable.insert({mempool_id, it->second.get()}).second; |
1377 | TORCH_INTERNAL_ASSERT(inserted); |
1378 | } |
1379 | } |
1380 | |
1381 | private: |
1382 | // All private methods do not acquire the allocator mutex. |
1383 | |
1384 | std::vector<const Block*> get_all_blocks() const { |
1385 | std::vector<const Block*> blocks; |
1386 | blocks.insert( |
1387 | blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end()); |
1388 | blocks.insert( |
1389 | blocks.end(), large_blocks.blocks.begin(), large_blocks.blocks.end()); |
1390 | for (const auto& gp : graph_pools) { |
1391 | blocks.insert( |
1392 | blocks.end(), |
1393 | gp.second->small_blocks.blocks.begin(), |
1394 | gp.second->small_blocks.blocks.end()); |
1395 | blocks.insert( |
1396 | blocks.end(), |
1397 | gp.second->large_blocks.blocks.begin(), |
1398 | gp.second->large_blocks.blocks.end()); |
1399 | } |
1400 | blocks.insert(blocks.end(), active_blocks.begin(), active_blocks.end()); |
1401 | return blocks; |
1402 | } |
1403 | |
1404 | /** moves a block into a pool of cached free blocks */ |
1405 | void free_block(Block* block) { |
1406 | TORCH_INTERNAL_ASSERT( |
1407 | !block->allocated && block->event_count == 0 && |
1408 | block->stream_uses.empty()); |
1409 | if (block->history) { |
1410 | record_trace( |
1411 | TraceEntry::FREE_COMPLETED, |
1412 | int64_t(block->ptr), |
1413 | block->history->h.real_size, |
1414 | block->stream, |
1415 | block->history->h.context); |
1416 | } |
1417 | size_t original_block_size = block->size; |
1418 | size_t requested_size = block->requested_size; |
1419 | |
1420 | auto& pool = *block->pool; |
1421 | int64_t net_change_inactive_split_blocks = 0; |
1422 | int64_t net_change_inactive_split_size = 0; |
1423 | |
1424 | const std::array<Block*, 2> merge_candidates = {block->prev, block->next}; |
1425 | for (Block* merge_candidate : merge_candidates) { |
1426 | const int64_t subsumed_size = |
1427 | try_merge_blocks(block, merge_candidate, pool); |
1428 | if (subsumed_size > 0) { |
1429 | net_change_inactive_split_blocks -= 1; |
1430 | net_change_inactive_split_size -= subsumed_size; |
1431 | } |
1432 | } |
1433 | |
1434 | active_blocks.erase(block); |
1435 | // Makes sure the Block* isn't already present in the pool we're freeing it |
1436 | // back into. |
1437 | bool inserted = pool.blocks.insert(block).second; |
1438 | TORCH_INTERNAL_ASSERT(inserted); |
1439 | |
1440 | if (block->is_split()) { |
1441 | net_change_inactive_split_blocks += 1; |
1442 | net_change_inactive_split_size += block->size; |
1443 | } |
1444 | |
1445 | StatTypes stat_types = {false}; |
1446 | stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
1447 | stat_types[static_cast<size_t>(get_stat_type_for_pool(pool))] = true; |
1448 | for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
1449 | update_stat( |
1450 | stats.inactive_split[stat_type], net_change_inactive_split_blocks); |
1451 | update_stat( |
1452 | stats.inactive_split_bytes[stat_type], |
1453 | net_change_inactive_split_size); |
1454 | update_stat(stats.active[stat_type], -1); |
1455 | update_stat( |
1456 | stats.active_bytes[stat_type], |
1457 | -static_cast<std::int64_t>(original_block_size)); |
1458 | update_stat( |
1459 | stats.requested_bytes[stat_type], |
1460 | -static_cast<std::int64_t>(requested_size)); |
1461 | }); |
1462 | } |
1463 | |
1464 | /** combine previously split blocks. returns the size of the subsumed block, |
1465 | * or 0 on failure. */ |
1466 | size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) { |
1467 | if (!src || src->allocated || src->event_count > 0 || |
1468 | !src->stream_uses.empty()) { |
1469 | return 0; |
1470 | } |
1471 | |
1472 | AT_ASSERT(dst->is_split() && src->is_split()); |
1473 | |
1474 | if (dst->prev == src) { // [src dst] |
1475 | dst->ptr = src->ptr; |
1476 | dst->prev = src->prev; |
1477 | if (dst->prev) { |
1478 | dst->prev->next = dst; |
1479 | } |
1480 | if (!dst->history) { |
1481 | dst->history = std::move(src->history); |
1482 | dst->history_last = src->history_last; |
1483 | } else if (src->history) { |
1484 | src->history_last->next = std::move(dst->history); |
1485 | dst->history = std::move(src->history); |
1486 | } |
1487 | src->history_last = nullptr; |
1488 | } else { // [dest src] |
1489 | dst->next = src->next; |
1490 | if (dst->next) { |
1491 | dst->next->prev = dst; |
1492 | } |
1493 | |
1494 | if (!dst->history) { |
1495 | dst->history = std::move(src->history); |
1496 | dst->history_last = src->history_last; |
1497 | } else if (src->history) { |
1498 | dst->history_last->next = std::move(src->history); |
1499 | dst->history_last = src->history_last; |
1500 | } |
1501 | src->history_last = nullptr; |
1502 | } |
1503 | const size_t subsumed_size = src->size; |
1504 | dst->size += subsumed_size; |
1505 | auto erased = pool.blocks.erase(src); |
1506 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(erased == 1); |
1507 | delete src; |
1508 | |
1509 | return subsumed_size; |
1510 | } |
1511 | |
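  /** returns the BlockPool a request of `size` on `stream` should be served
   *  from, routing to a CUDA graph's private pool if the stream is capturing */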
1512 | BlockPool& get_pool(size_t size, cudaStream_t stream) { |
1513 | #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 |
1514 | // captures_underway is a conservative guess that the current stream may be |
1515 | // capturing. It's only > 0 if some thread has begun and not yet ended a |
1516 | // capture, so it's usually 0, and we can short-circuit |
1517 | // cudaStreamCaptureStatus (which does a TLS lookup). |
1518 | if (C10_UNLIKELY(captures_underway)) { |
1519 | CaptureId_t id; |
1520 | cudaStreamCaptureStatus status; |
1521 | C10_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &id)); |
1522 | if (status != cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) { |
1523 | TORCH_INTERNAL_ASSERT( |
1524 | status != |
1525 | cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated); |
1526 | // Retrieves the private pool assigned to this capture. |
1527 | auto it0 = capture_to_pool_map.find(id); |
1528 | TORCH_INTERNAL_ASSERT(it0 != capture_to_pool_map.end()); |
1529 | auto it1 = graph_pools.find(it0->second); |
1530 | TORCH_INTERNAL_ASSERT(it1 != graph_pools.end()); |
1531 | if (size <= kSmallSize) { |
1532 | return it1->second->small_blocks; |
1533 | } else { |
1534 | return it1->second->large_blocks; |
1535 | } |
1536 | } |
1537 | } |
1538 | #endif |
1539 | if (size <= kSmallSize) { |
1540 | return small_blocks; |
1541 | } else { |
1542 | return large_blocks; |
1543 | } |
1544 | } |
1545 | |
1546 | StatType get_stat_type_for_pool(const BlockPool& pool) { |
1547 | return pool.is_small ? StatType::SMALL_POOL : StatType::LARGE_POOL; |
1548 | } |
1549 | |
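  /** decides whether a block should be split after serving a request of
   *  `size`, based on how much of the block would remain unused */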
1550 | bool should_split(const Block* block, size_t size) { |
1551 | size_t remaining = block->size - size; |
1552 | if (block->pool->is_small) { |
1553 | return remaining >= kMinBlockSize; |
1554 | } else { |
1555 | return (size < CachingAllocatorConfig::max_split_size()) && |
1556 | (remaining > kSmallSize); |
1557 | } |
1558 | } |
1559 | |
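  /** rounds a requested size up to the size actually passed to cudaMalloc:
   *  kSmallBuffer for small requests, kLargeBuffer for mid-sized requests,
   *  and a multiple of kRoundLarge for anything >= kMinLargeAlloc */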
1560 | static size_t get_allocation_size(size_t size) { |
1561 | if (size <= kSmallSize) { |
1562 | return kSmallBuffer; |
1563 | } else if (size < kMinLargeAlloc) { |
1564 | return kLargeBuffer; |
1565 | } else { |
1566 | return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); |
1567 | } |
1568 | } |
1569 | |
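  /** looks for the best-fit cached block on the request's stream; on success
   *  removes it from the pool, stores it in p.block, and returns true */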
1570 | bool get_free_block(AllocParams& p) { |
1571 | BlockPool& pool = *p.pool; |
1572 | |
1573 | if (C10_UNLIKELY( |
1574 | set_fraction && |
1575 | CachingAllocatorConfig::garbage_collection_threshold() > 0.0)) { |
1576 | // Track block reuse interval only when garbage collection is enabled. |
1577 | for (auto& b : pool.blocks) { |
1578 | ++b->gc_count; |
1579 | } |
1580 | } |
1581 | auto it = pool.blocks.lower_bound(&p.search_key); |
1582 | if (it == pool.blocks.end() || (*it)->stream != p.stream()) |
1583 | return false; |
1584 | // Do not return an oversized block for a large request |
1585 | if ((p.size() < CachingAllocatorConfig::max_split_size()) && |
1586 | ((*it)->size >= CachingAllocatorConfig::max_split_size())) |
1587 | return false; |
1588 | // Allow oversized block size to be rounded up but within a limit |
1589 | if ((p.size() >= CachingAllocatorConfig::max_split_size()) && |
1590 | ((*it)->size >= p.size() + kLargeBuffer)) |
1591 | return false; |
1592 | p.block = *it; |
    (*it)->gc_count = 0; // Mark the block as freshly used (resets its GC age)
1594 | pool.blocks.erase(it); |
1595 | return true; |
1596 | } |
1597 | |
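  // Invokes every callback registered with FreeCudaMemoryCallbacksRegistry
  // and reports whether any of them freed memory.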
1598 | bool trigger_free_memory_callbacks(AllocParams& p) { |
1599 | bool freed_memory = false; |
1600 | for (const auto& name : FreeCudaMemoryCallbacksRegistry()->Keys()) { |
1601 | freed_memory |= |
1602 | FreeCudaMemoryCallbacksRegistry()->Create(name)->Execute(); |
1603 | } |
1604 | return freed_memory; |
1605 | } |
1606 | |
1607 | void garbage_collect_cached_blocks() { |
    // Free unused cached blocks to reclaim GPU memory.
    // Unlike release_cached_blocks(), this does not enforce synchronization
    // and therefore incurs less overhead.
1611 | |
1612 | size_t gc_threshold = static_cast<size_t>( |
1613 | CachingAllocatorConfig::garbage_collection_threshold() * |
1614 | allowed_memory_maximum); |
1615 | // No need to trigger GC yet |
1616 | if (total_allocated_memory <= gc_threshold) { |
1617 | return; |
1618 | } |
1619 | const auto target_size = total_allocated_memory - gc_threshold; |
1620 | size_t gc_reclaimed = 0; |
1621 | |
    // Calculate the total age of the free-able blocks. We'll use it later to
    // compute the average-age threshold.
1624 | double total_age = 0.0; |
1625 | int freeable_block_count = 0; |
1626 | for (auto& b : large_blocks.blocks) { |
1627 | if (!b->is_split()) { |
1628 | total_age += b->gc_count; |
1629 | ++freeable_block_count; |
1630 | } |
1631 | } |
1632 | // No free-able blocks? |
1633 | if (freeable_block_count == 0) { |
1634 | return; |
1635 | } |
1636 | |
    // Repeat GC until the reclaimed size exceeds the target size.
1638 | bool block_freed = true; |
1639 | while (gc_reclaimed < target_size && block_freed == true && |
1640 | freeable_block_count > 0) { |
1641 | // Free blocks exceeding this age threshold first. |
1642 | double age_threshold = total_age / freeable_block_count; |
1643 | // Stop iteration if we can no longer free a block. |
1644 | block_freed = false; |
1645 | |
      // Free blocks whose age exceeds the average. Don't stop upon reaching
      // target_size; we don't want this GC to be triggered frequently.
1648 | auto it = large_blocks.blocks.begin(); |
1649 | while (it != large_blocks.blocks.end()) { |
1650 | Block* block = *it; |
1651 | ++it; |
1652 | if (!block->is_split() && block->gc_count >= age_threshold) { |
1653 | block_freed = true; |
1654 | gc_reclaimed += block->size; |
1655 | total_age -= block->gc_count; // Decrement the age |
1656 | freeable_block_count--; // One less block that can be freed |
1657 | release_block(block); |
1658 | } |
1659 | } |
1660 | } |
1661 | } |
1662 | |
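  /** allocates a new segment from the system allocator (cudaMalloc) and wraps
   *  it in a new Block; on failure records the error in p.err and returns
   *  false */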
1663 | bool alloc_block(AllocParams& p, bool isRetry) { |
1664 | // Defensively checks for preexisting CUDA error state. |
1665 | C10_CUDA_CHECK(cudaGetLastError()); |
1666 | |
1667 | size_t size = p.alloc_size; |
1668 | void* ptr; |
1669 | |
1670 | if (isRetry) { |
1671 | stats.num_alloc_retries += 1; |
1672 | } |
1673 | |
1674 | if (set_fraction && |
1675 | total_allocated_memory + size > allowed_memory_maximum) { |
1676 | p.err = cudaErrorMemoryAllocation; |
1677 | return false; |
1678 | } else { |
1679 | p.err = cudaMallocMaybeCapturing(&ptr, size); |
1680 | if (p.err != cudaSuccess) { |
1681 | if (p.err == cudaErrorMemoryAllocation) { |
1682 | // If this is the first attempt (!isRetry), we can forgive and clear |
1683 | // CUDA's internal error state. |
1684 | // |
1685 | // If this is the second attempt (isRetry), malloc's TORCH_CHECK_WITH |
1686 | // will take over to throw a helpful exception. The user can choose |
1687 | // to catch the exception, free some stuff in their script, and |
1688 | // attempt the allocation again. In this case, we can also forgive and |
1689 | // clear CUDA's internal error state. |
1690 | cudaGetLastError(); |
1691 | } else { |
1692 | // If the error's unrelated to memory allocation, we should throw |
1693 | // immediately. |
1694 | C10_CUDA_CHECK(p.err); |
1695 | } |
1696 | return false; |
1697 | } |
1698 | } |
1699 | |
1700 | if (p.pool->owner_PrivatePool) { |
1701 | // The block is for a CUDA graph's PrivatePool. |
1702 | p.pool->owner_PrivatePool->cudaMalloc_count++; |
1703 | } |
1704 | |
1705 | total_allocated_memory += size; |
1706 | p.block = new Block(p.device(), p.stream(), size, p.pool, (char*)ptr); |
1707 | for_each_selected_stat_type(p.stat_types, [&](size_t stat_type) { |
1708 | update_stat(stats.segment[stat_type], 1); |
1709 | update_stat(stats.reserved_bytes[stat_type], size); |
1710 | }); |
1711 | if (size >= CachingAllocatorConfig::max_split_size()) |
1712 | update_stat(stats.oversize_segments, 1); |
1713 | |
1714 | // p.block came from new, not cudaMalloc. It should not be nullptr here. |
1715 | TORCH_INTERNAL_ASSERT(p.block != nullptr && p.block->ptr != nullptr); |
1716 | return true; |
1717 | } |
1718 | |
  /** Free one or more oversize blocks to the system allocator, but only
   *  enough to satisfy the target size. **/
1722 | bool release_available_cached_blocks(const AllocParams& p) { |
1723 | if (CachingAllocatorConfig::max_split_size() == |
1724 | std::numeric_limits<size_t>::max()) |
1725 | return false; |
1726 | BlockPool& pool = *p.pool; |
1727 | |
    // Because Block holds a std::unique_ptr, it cannot be trivially copied;
    // construct the search key field-by-field instead.
1729 | Block key( |
1730 | p.search_key.device, |
1731 | p.search_key.stream, |
1732 | p.search_key.size, |
1733 | p.search_key.pool, |
1734 | p.search_key.ptr); |
1735 | key.size = (key.size < CachingAllocatorConfig::max_split_size()) |
1736 | ? CachingAllocatorConfig::max_split_size() |
1737 | : key.size; |
1738 | auto it = pool.blocks.lower_bound(&key); |
1739 | if (it == pool.blocks.end() || (*it)->stream != p.stream()) { |
1740 | // No single block is large enough; free multiple oversize blocks, |
1741 | // starting with the largest |
1742 | if (it == pool.blocks.begin()) |
1743 | return false; |
1744 | size_t totalReleased = 0; |
1745 | --it; // Back up one item. Now on the largest block for the correct |
1746 | // stream |
1747 | while ((totalReleased < key.size) && |
1748 | ((*it)->size >= CachingAllocatorConfig::max_split_size()) && |
1749 | ((*it)->stream == p.stream())) { |
1750 | auto cur = it; |
1751 | totalReleased += (*it)->size; |
1752 | if (it != pool.blocks.begin()) { |
1753 | --it; |
1754 | release_block(*cur); |
1755 | } else { |
1756 | release_block(*cur); |
1757 | break; |
1758 | } |
1759 | } |
1760 | if (totalReleased < key.size) |
1761 | return false; |
1762 | } else { |
1763 | release_block(*it); |
1764 | } |
1765 | return true; |
1766 | } |
1767 | |
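  /** returns all non-split cached blocks, including those owned by freeable
   *  CUDA graph private pools, to the system allocator */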
1768 | bool release_cached_blocks() { |
1769 | // First ensure that all blocks that can't currently be allocated due to |
1770 | // outstanding events are returned to the pool. |
1771 | synchronize_and_free_events(); |
1772 | |
1773 | // Free all non-split cached blocks to system allocator |
1774 | release_blocks(large_blocks); |
1775 | release_blocks(small_blocks); |
1776 | |
1777 | for (auto it = graph_pools_freeable.begin(); |
1778 | it != graph_pools_freeable.end();) { |
1779 | // See notifyCaptureDestroy for the strategy here. |
1780 | TORCH_INTERNAL_ASSERT(it->second->use_count == 0); |
1781 | release_blocks(it->second->small_blocks); |
1782 | release_blocks(it->second->large_blocks); |
1783 | if (it->second->cudaMalloc_count == 0) { |
1784 | auto erase_count = graph_pools.erase(it->first); |
1785 | TORCH_INTERNAL_ASSERT(erase_count == 1); |
1786 | it = graph_pools_freeable.erase(it); |
1787 | } else { |
1788 | ++it; |
1789 | } |
1790 | } |
1791 | |
1792 | return true; |
1793 | } |
1794 | |
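  /** cudaFrees a single cached block, removes it from its pool, and updates
   *  the segment statistics and allocation trace */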
1795 | void release_block(Block* block) { |
1796 | C10_CUDA_CHECK(cudaFree((void*)block->ptr)); |
1797 | total_allocated_memory -= block->size; |
1798 | |
1799 | auto* pool = block->pool; |
1800 | if (pool->owner_PrivatePool) { |
1801 | // The cudaFreed block belonged to a CUDA graph's PrivatePool. |
1802 | TORCH_INTERNAL_ASSERT(pool->owner_PrivatePool->cudaMalloc_count > 0); |
1803 | pool->owner_PrivatePool->cudaMalloc_count--; |
1804 | } |
1805 | |
1806 | StatTypes stat_types = {false}; |
1807 | stat_types[static_cast<size_t>(StatType::AGGREGATE)] = true; |
1808 | stat_types[static_cast<size_t>(get_stat_type_for_pool(*pool))] = true; |
1809 | for_each_selected_stat_type(stat_types, [&](size_t stat_type) { |
1810 | update_stat(stats.segment[stat_type], -1); |
1811 | update_stat( |
1812 | stats.reserved_bytes[stat_type], |
1813 | -static_cast<std::int64_t>(block->size)); |
1814 | }); |
1815 | if (block->size >= CachingAllocatorConfig::max_split_size()) |
1816 | update_stat(stats.oversize_segments, -1); |
1817 | if (block->history) { |
1818 | record_trace( |
1819 | TraceEntry::SEGMENT_FREE, |
1820 | int64_t(block->ptr), |
1821 | block->size, |
1822 | block->stream, |
1823 | block->history->h.context); |
1824 | } |
1825 | pool->blocks.erase(block); |
1826 | delete block; |
1827 | } |
1828 | |
1829 | void release_blocks(BlockPool& pool) { |
1830 | // Frees all non-split blocks |
1831 | auto it = pool.blocks.begin(); |
1832 | while (it != pool.blocks.end()) { |
1833 | Block* block = *it; |
1834 | ++it; |
1835 | if (!block->prev && !block->next) { |
1836 | release_block(block); |
1837 | } |
1838 | } |
1839 | } |
1840 | |
1841 | EventPool::Event create_event_internal(int idx) { |
1842 | // Leak the event pool to avoid shutdown issues. |
1843 | static auto* event_pool = new EventPool(); |
1844 | return event_pool->get(idx); |
1845 | } |
1846 | |
1847 | void synchronize_and_free_events() { |
1848 | // Synchronize on outstanding events and then free associated blocks. |
1849 | |
1850 | // This function syncs, so capture should not be underway. Might as well |
1851 | // make sure capture-deferred end of life events get processed too. |
1852 | TORCH_INTERNAL_ASSERT(captures_underway == 0); |
1853 | insert_events_deferred_until_no_capture(); |
1854 | |
1855 | for (auto& st : cuda_events) { |
1856 | for (auto& e : st.second) { |
1857 | EventPool::Event event = std::move(e.first); |
1858 | Block* block = e.second; |
1859 | |
1860 | C10_CUDA_CHECK(cudaEventSynchronize(*event)); |
1861 | |
1862 | block->event_count--; |
1863 | if (block->event_count == 0) { |
1864 | free_block(block); |
1865 | } |
1866 | } |
1867 | } |
1868 | |
1869 | cuda_events.clear(); |
1870 | } |
1871 | |
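  // Records a CUDA event on every stream that has used this block
  // (block->stream_uses); the block is only returned to its pool once all of
  // those events have completed (see process_events and
  // synchronize_and_free_events).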
1872 | void insert_events(Block* block) { |
1873 | int prev_device; |
1874 | C10_CUDA_CHECK(cudaGetDevice(&prev_device)); |
1875 | |
1876 | stream_set streams(std::move(block->stream_uses)); |
1877 | AT_ASSERT(block->stream_uses.empty()); |
1878 | for (auto& stream : streams) { |
1879 | C10_CUDA_CHECK(cudaSetDevice(stream.device_index())); |
1880 | |
1881 | EventPool::Event event = |
1882 | create_event_internal(static_cast<int>(stream.device_index())); |
1883 | C10_CUDA_CHECK(cudaEventRecord(*event, stream.stream())); |
1884 | |
1885 | block->event_count++; |
1886 | cuda_events[stream].emplace_back(std::move(event), block); |
1887 | } |
1888 | |
1889 | C10_CUDA_CHECK(cudaSetDevice(prev_device)); |
1890 | } |
1891 | |
1892 | void insert_events_deferred_until_no_capture() { |
1893 | if (C10_UNLIKELY(needs_events_deferred_until_no_capture.size() > 0)) { |
1894 | for (auto* block : needs_events_deferred_until_no_capture) { |
1895 | TORCH_INTERNAL_ASSERT(!block->stream_uses.empty()); |
1896 | insert_events(block); |
1897 | } |
1898 | needs_events_deferred_until_no_capture.clear(); |
1899 | } |
1900 | } |
1901 | |
1902 | void process_events() { |
1903 | insert_events_deferred_until_no_capture(); |
1904 | |
1905 | // Process outstanding cudaEvents. Events that are completed are |
1906 | // removed from the queue, and the 'event_count' for the |
1907 | // corresponding allocation is decremented. We maintain a separate |
1908 | // list of events per stream to avoid head-of-line delays if one |
1909 | // or more streams has long-running operations. |
1910 | |
1911 | // Iterate over different streams. |
1912 | for (auto it = cuda_events.begin(); it != cuda_events.end();) { |
1913 | // Iterate over this stream's (event, block) pairs. |
1914 | while (!it->second.empty()) { |
1915 | auto& e = it->second.front(); |
1916 | EventPool::Event event = std::move(e.first); |
1917 | Block* block = e.second; |
1918 | |
1919 | cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(*event)); |
1920 | if (err == cudaErrorNotReady) { |
1921 | // ignore and clear the error if not ready |
1922 | cudaGetLastError(); |
          // Return ownership of the Event (a unique_ptr) to the queue entry
1924 | e.first = std::move(event); |
1925 | break; |
1926 | } else if (err != cudaSuccess) { |
1927 | C10_CUDA_CHECK(err); |
1928 | } |
1929 | |
1930 | block->event_count--; |
1931 | if (block->event_count == 0) { |
1932 | free_block(block); |
1933 | } |
1934 | it->second.pop_front(); |
1935 | } |
1936 | |
1937 | if (it->second.empty()) { |
1938 | it = cuda_events.erase(it); |
1939 | } else { |
1940 | it++; |
1941 | } |
1942 | } |
1943 | } |
1944 | |
  // Iterates over all memory blocks in the given pool and records the size
  // of the largest one in *largest.
1946 | void cache_info_aux(const BlockPool& pool, size_t* largest) { |
1947 | for (const auto& block : pool.blocks) { |
1948 | const auto blocksize = block->size; |
1949 | if (blocksize > *largest) { |
1950 | *largest = blocksize; |
1951 | } |
1952 | } |
1953 | } |
1954 | |
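  // Appends an entry to the allocation trace, treating alloc_trace as a ring
  // buffer once it reaches alloc_trace_max_entries_ entries.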
1955 | void record_trace( |
1956 | TraceEntry::Action action, |
1957 | int64_t addr, |
1958 | size_t size, |
1959 | cudaStream_t stream, |
1960 | std::shared_ptr<Context> context) { |
1961 | auto te = TraceEntry( |
1962 | action, |
1963 | addr, |
1964 | size, |
1965 | stream, |
1966 | alloc_trace_record_context_ ? std::move(context) : nullptr); |
1967 | if (alloc_trace->size() < alloc_trace_max_entries_) { |
1968 | alloc_trace->emplace_back(te); |
1969 | } else { |
1970 | (*alloc_trace)[alloc_trace_next++] = te; |
1971 | if (alloc_trace_next == alloc_trace_max_entries_) { |
1972 | alloc_trace_next = 0; |
1973 | } |
1974 | } |
1975 | } |
1976 | }; |
1977 | |
1978 | // Returns whether to force all allocations to bypass the caching allocator and |
1979 | // go straight to cudaMalloc. This setting is useful when debugging GPU memory |
1980 | // errors, since the caching allocator foils cuda-memcheck. |
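// For example (illustrative shell invocation, not part of this file):
//   PYTORCH_NO_CUDA_MEMORY_CACHING=1 python my_script.py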
1981 | bool forceUncachedAllocator() { |
1982 | static bool force_uncached = |
      getenv("PYTORCH_NO_CUDA_MEMORY_CACHING") != nullptr;
1984 | return force_uncached; |
1985 | } |
1986 | |
1987 | static void uncached_delete(void* ptr) { |
1988 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
1989 | if (C10_UNLIKELY(interp)) { |
1990 | (*interp)->trace_gpu_memory_deallocation(reinterpret_cast<uintptr_t>(ptr)); |
1991 | } |
1992 | C10_CUDA_CHECK(cudaFree(ptr)); |
1993 | } |
1994 | |
1995 | void local_raw_delete(void* ptr); |
1996 | |
1997 | class NativeCachingAllocator : public CUDAAllocator { |
1998 | private: |
1999 | std::mutex mutex; |
2000 | |
2001 | // allocated blocks by device pointer |
2002 | ska::flat_hash_map<void*, Block*> allocated_blocks; |
2003 | |
2004 | void add_allocated_block(Block* block) { |
2005 | std::lock_guard<std::mutex> lock(mutex); |
2006 | allocated_blocks[block->ptr] = block; |
2007 | } |
2008 | |
2009 | public: |
2010 | std::vector<std::unique_ptr<DeviceCachingAllocator>> device_allocator; |
2011 | |
2012 | Block* get_allocated_block(void* ptr, bool remove = false) { |
2013 | std::lock_guard<std::mutex> lock(mutex); |
2014 | auto it = allocated_blocks.find(ptr); |
2015 | if (it == allocated_blocks.end()) { |
2016 | return nullptr; |
2017 | } |
2018 | Block* block = it->second; |
2019 | if (remove) { |
2020 | allocated_blocks.erase(it); |
2021 | } |
2022 | return block; |
2023 | } |
2024 | |
2025 | void init(int device_count) override { |
2026 | const auto size = static_cast<int64_t>(device_allocator.size()); |
2027 | if (size < device_count) { |
2028 | device_allocator.resize(device_count); |
2029 | for (const auto i : c10::irange(size, device_count)) { |
2030 | device_allocator[i] = std::make_unique<DeviceCachingAllocator>(); |
2031 | } |
2032 | } |
2033 | } |
2034 | |
2035 | bool initialized() override { |
2036 | return device_allocator.size() > 0; |
2037 | } |
2038 | |
2039 | /** allocates a block which is safe to use from the provided stream */ |
2040 | void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) { |
2041 | TORCH_INTERNAL_ASSERT( |
2042 | 0 <= device && static_cast<size_t>(device) < device_allocator.size(), |
2043 | "Allocator not initialized for device " , |
2044 | device, |
2045 | ": did you call init?" ); |
2046 | Block* block = device_allocator[device]->malloc(device, size, stream); |
2047 | add_allocated_block(block); |
2048 | *devPtr = (void*)block->ptr; |
2049 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
2050 | if (C10_UNLIKELY(interp)) { |
2051 | (*interp)->trace_gpu_memory_allocation( |
2052 | reinterpret_cast<uintptr_t>(*devPtr)); |
2053 | } |
2054 | } |
2055 | |
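  /** frees a pointer previously returned by malloc(); looks up the owning
   *  Block by device pointer and hands it back to that device's allocator */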
2056 | void free(void* ptr) { |
2057 | if (!ptr) { |
2058 | return; |
2059 | } |
2060 | Block* block = get_allocated_block(ptr, true /* remove */); |
2061 | if (!block) { |
2062 | TORCH_CHECK(false, "invalid device pointer: " , ptr); |
2063 | } |
2064 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
2065 | if (C10_UNLIKELY(interp)) { |
2066 | (*interp)->trace_gpu_memory_deallocation( |
2067 | reinterpret_cast<uintptr_t>(block->ptr)); |
2068 | } |
2069 | device_allocator[block->device]->free(block); |
2070 | } |
2071 | |
2072 | void setMemoryFraction(double fraction, int device) override { |
2073 | TORCH_INTERNAL_ASSERT( |
2074 | 0 <= device && static_cast<size_t>(device) < device_allocator.size(), |
2075 | "Allocator not initialized for device " , |
2076 | device, |
2077 | ": did you call init?" ); |
2078 | TORCH_INTERNAL_ASSERT( |
2079 | 0 <= fraction && fraction <= 1, |
2080 | "invalid fraction:" , |
2081 | fraction, |
2082 | ". Please set within (0, 1)." ); |
2083 | int activated_device; |
2084 | C10_CUDA_CHECK(cudaGetDevice(&activated_device)); |
2085 | if (activated_device != device) { |
2086 | C10_CUDA_CHECK(cudaSetDevice(device)); |
2087 | } |
2088 | device_allocator[device]->setMemoryFraction(fraction); |
2089 | } |
2090 | |
2091 | void recordHistory( |
2092 | bool enabled, |
2093 | CreateContextFn context_recorder, |
2094 | size_t alloc_trace_max_entries, |
2095 | bool alloc_trace_record_context) override { |
2096 | int device; |
2097 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
2098 | device_allocator[device]->recordHistory( |
2099 | enabled, |
2100 | std::move(context_recorder), |
2101 | alloc_trace_max_entries, |
2102 | alloc_trace_record_context); |
2103 | } |
2104 | |
2105 | void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override { |
2106 | int device; |
2107 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
2108 | device_allocator[device]->attachOutOfMemoryObserver(std::move(observer)); |
2109 | } |
2110 | |
2111 | void emptyCache() override { |
2112 | for (auto& da : device_allocator) |
2113 | da->emptyCache(); |
2114 | } |
2115 | |
2116 | void* getBaseAllocation(void* ptr, size_t* outSize) override { |
2117 | Block* block = get_allocated_block(ptr); |
2118 | if (!block) { |
2119 | TORCH_CHECK(false, "invalid device pointer: " , ptr); |
2120 | } |
2121 | return device_allocator[block->device]->getBaseAllocation(block, outSize); |
2122 | } |
2123 | |
2124 | void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) override { |
    // An empty tensor's storage().data() might be a null ptr. As there are no
    // blocks associated with those tensors, it is fine to do nothing here.
2127 | if (!ptr.get()) { |
2128 | return; |
2129 | } |
2130 | |
    // If a tensor is not allocated by this instance, simply skip it.
    // This usually happens when CUDA tensors are shared across processes: we
    // have implemented a reference-counting-based sharing mechanism to
    // guarantee tensors won't be accidentally freed by one process while
    // they are still being used in another.
2136 | if (ptr.get_deleter() != &local_raw_delete) |
2137 | return; |
2138 | |
2139 | Block* block = get_allocated_block(ptr.get()); |
    // block must not be null by the time we reach here
2141 | TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found" ); |
2142 | device_allocator[block->device]->recordStream(block, stream); |
2143 | } |
2144 | |
2145 | SnapshotInfo snapshot() override { |
2146 | SnapshotInfo result; |
2147 | for (auto& da : device_allocator) { |
2148 | result.device_traces.emplace_back(da->trace()); |
2149 | auto snap = da->snapshot(); |
2150 | result.segments.insert(result.segments.end(), snap.begin(), snap.end()); |
2151 | } |
2152 | return result; |
2153 | } |
2154 | DataPtr allocate(size_t size) const override { |
2155 | constexpr size_t one_exa_bytes = 1152921504606846976ULL; |
2156 | TORCH_CHECK_WITH( |
2157 | OutOfMemoryError, |
2158 | size < one_exa_bytes, |
2159 | "CUDA out of memory. Tried to allocate more than 1EB memory." ); |
2160 | int device; |
2161 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
2162 | void* r = nullptr; |
2163 | if (forceUncachedAllocator()) { |
2164 | // Deliberately don't use cudaMallocMaybeCapturing here, to force an error |
2165 | // if someone tries to use forceUncachedAllocator while capturing. |
2166 | C10_CUDA_CHECK(cudaMalloc(&r, size)); |
2167 | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
2168 | if (C10_UNLIKELY(interp)) { |
2169 | (*interp)->trace_gpu_memory_allocation(reinterpret_cast<uintptr_t>(r)); |
2170 | } |
2171 | return {r, r, &uncached_delete, Device(DeviceType::CUDA, device)}; |
2172 | } |
2173 | if (size != 0) { |
      // Allocator declares allocate() const!?
2175 | const_cast<NativeCachingAllocator*>(this)->malloc( |
2176 | &r, device, size, cuda::getCurrentCUDAStream(device)); |
2177 | } |
2178 | return {r, r, &local_raw_delete, Device(DeviceType::CUDA, device)}; |
2179 | } |
2180 | DeleterFnPtr raw_deleter() const override { |
2181 | if (forceUncachedAllocator()) { |
2182 | return &uncached_delete; |
2183 | } else { |
2184 | return &local_raw_delete; |
2185 | } |
2186 | } |
2187 | void cacheInfo(int dev_id, size_t* largestBlock) override { |
2188 | device_allocator[dev_id]->cacheInfo(largestBlock); |
2189 | } |
2190 | void assertValidDevice(int device) { |
2191 | const auto device_num = device_allocator.size(); |
2192 | TORCH_CHECK( |
2193 | 0 <= device && device < static_cast<int64_t>(device_num), |
2194 | "Invalid device argument " , |
2195 | device, |
2196 | ": did you call init?" ); |
2197 | } |
2198 | |
2199 | DeviceStats getDeviceStats(int device) override { |
2200 | assertValidDevice(device); |
2201 | return device_allocator[device]->getStats(); |
2202 | } |
2203 | |
2204 | void resetAccumulatedStats(int device) override { |
2205 | assertValidDevice(device); |
2206 | device_allocator[device]->resetAccumulatedStats(); |
2207 | } |
2208 | |
2209 | void resetPeakStats(int device) override { |
2210 | assertValidDevice(device); |
2211 | device_allocator[device]->resetPeakStats(); |
2212 | } |
2213 | // CUDAGraph interactions |
2214 | void notifyCaptureBegin( |
2215 | int device, |
2216 | CaptureId_t graph_id, |
2217 | MempoolId_t mempool_id) override { |
2218 | assertValidDevice(device); |
2219 | device_allocator[device]->notifyCaptureBegin( |
2220 | graph_id, std::move(mempool_id)); |
2221 | } |
2222 | |
2223 | void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) override { |
2224 | assertValidDevice(device); |
2225 | device_allocator[device]->notifyCaptureAboutToEnd(graph_id); |
2226 | } |
2227 | |
2228 | void notifyCaptureEnded(int device, CaptureId_t graph_id) override {} // no-op |
2229 | |
2230 | void notifyCaptureDestroy(int device, MempoolId_t mempool_id) override { |
2231 | assertValidDevice(device); |
2232 | device_allocator[device]->notifyCaptureDestroy(std::move(mempool_id)); |
2233 | } |
2234 | |
2235 | void* raw_alloc(size_t nbytes) override { |
2236 | if (nbytes == 0) { |
2237 | return nullptr; |
2238 | } |
2239 | int device; |
2240 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
2241 | void* r = nullptr; |
2242 | malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device)); |
2243 | return r; |
2244 | } |
2245 | |
2246 | void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override { |
2247 | if (nbytes == 0) { |
2248 | return nullptr; |
2249 | } |
2250 | int device; |
2251 | C10_CUDA_CHECK(cudaGetDevice(&device)); |
2252 | void* r = nullptr; |
2253 | malloc(&r, device, nbytes, stream); |
2254 | return r; |
2255 | } |
2256 | bool needsPoolSpecificPeerAccess() override { |
2257 | return false; |
2258 | } |
2259 | void raw_delete(void* ptr) override { |
2260 | this->free(ptr); |
2261 | } |
2262 | |
  // In CUDA IPC, the sender sends a tensor to the receiver; getIpcDevPtr
  // is called by the receiving process to map the CUDA memory from the
  // sending process into its own address space.
  //
  // CUDA IPC only allows sharing a big memory block associated with a
  // cudaIpcMemHandle_t, and it can be opened only **once** per context per
  // process. There can be multiple types of storage in the same IPC mem
  // block, so we must cache the device ptr to construct typed storages as
  // they come.
  //
  // ipcMemHandle_to_devptr maps a cudaIpcMemHandle_t to a device pointer in
  // this process that can be used to access the memory block in the sender
  // process. It only saves a weak_ptr of the device pointer in the map; the
  // shared_ptr will be used to reconstruct all storages in this cudaMalloc
  // allocation, and it will be deleted via cudaIpcCloseMemHandle once its
  // reference count drops to 0.
2277 | // |
2278 | std::mutex IpcMutex; |
2279 | ska::flat_hash_map<std::string, std::weak_ptr<void>> ipcMemHandle_to_devptr; |
2280 | std::shared_ptr<void> getIpcDevPtr(std::string handle) override { |
2281 | std::lock_guard<std::mutex> lock(IpcMutex); |
2282 | |
2283 | auto iter = ipcMemHandle_to_devptr.find(handle); |
2284 | if (iter != ipcMemHandle_to_devptr.end()) { |
2285 | auto devptr = iter->second.lock(); |
2286 | if (devptr) |
2287 | return devptr; |
2288 | } |
    // This ipcMemHandle hasn't been opened, or it has already expired. Open
    // it to enable IPC access to that mem block.
2291 | void* dev = nullptr; |
2292 | auto ipc_handle = |
2293 | reinterpret_cast<const cudaIpcMemHandle_t*>(handle.c_str()); |
2294 | C10_CUDA_CHECK(cudaIpcOpenMemHandle( |
2295 | &dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess)); |
    // devPtr has to be deleted on the same device it was created on.
2297 | int curr_device; |
2298 | C10_CUDA_CHECK(cudaGetDevice(&curr_device)); |
2299 | auto sp = |
2300 | std::shared_ptr<void>(dev, [handle, curr_device, this](void* ptr) { |
2301 | cuda::CUDAGuard device_guard(curr_device); |
2302 | std::lock_guard<std::mutex> deleter_lock(IpcMutex); |
2303 | C10_CUDA_CHECK(cudaIpcCloseMemHandle(ptr)); |
2304 | ipcMemHandle_to_devptr.erase(handle); |
2305 | }); |
2306 | std::weak_ptr<void> wp = sp; |
    // To eliminate an additional search, we can use insert().
    // It doesn't overwrite when the key already exists (ptr expired), but
    // since the deleter for sp erases the entry, this should be safe to do
    // now.
2311 | ipcMemHandle_to_devptr.insert(iter, {handle, wp}); |
2312 | |
2313 | return sp; |
2314 | } |
2315 | std::string name() override { |
    return "native";
2317 | } |
2318 | }; |
2319 | |
2320 | NativeCachingAllocator allocator; |
2321 | |
2322 | void local_raw_delete(void* ptr) { |
2323 | allocator.free(ptr); |
2324 | } |
2325 | |
2330 | } // namespace Native |
2331 | |
2332 | // General caching allocator utilities |
2333 | void setAllocatorSettings(const std::string& env) { |
2334 | CachingAllocatorConfig::instance().parseArgs(env.c_str()); |
2335 | } |
2336 | |
2337 | // Size pretty-printer |
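// e.g. format_size(512) -> "512 bytes", format_size(2097152) -> "2.00 MiB".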
2338 | inline std::string format_size(uint64_t size) { |
2339 | std::ostringstream os; |
2340 | os.precision(2); |
2341 | os << std::fixed; |
2342 | if (size <= 1024) { |
2343 | os << size << " bytes" ; |
2344 | } else if (size <= 1048576) { |
2345 | os << (size / 1024.0); |
2346 | os << " KiB" ; |
2347 | } else if (size <= 1073741824ULL) { |
2348 | os << size / 1048576.0; |
2349 | os << " MiB" ; |
2350 | } else { |
2351 | os << size / 1073741824.0; |
2352 | os << " GiB" ; |
2353 | } |
2354 | return os.str(); |
2355 | } |
2356 | |
2357 | namespace CudaMallocAsync { |
2358 | // If this is put in its own header file, it gets incorrectly renamed in HIPify. |
2359 | CUDAAllocator* allocator(); |
2360 | |
2361 | } // namespace CudaMallocAsync |
2362 | |
2363 | struct BackendStaticInitializer { |
2364 | // Parses env for backend at load time, duplicating some logic from |
2365 | // CachingAllocatorConfig. CachingAllocatorConfig double-checks it later (at |
  // runtime). Defers verbose exceptions and error checks, including CUDA
  // version checks, to CachingAllocatorConfig's runtime double-check. If this
  // works, maybe we should move all of CachingAllocatorConfig here?
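  // For example (illustrative shell invocation, not part of this file):
  //   PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync python my_script.py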
2369 | CUDAAllocator* parseEnvForBackend() { |
    const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF");
2371 | if (val != nullptr) { |
2372 | const std::string config(val); |
2373 | |
2374 | std::regex exp("[\\s,]+" ); |
2375 | std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); |
2376 | std::sregex_token_iterator end; |
2377 | std::vector<std::string> options(it, end); |
2378 | |
2379 | for (auto option : options) { |
2380 | std::regex exp2("[:]+" ); |
2381 | std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1); |
2382 | std::sregex_token_iterator end2; |
2383 | std::vector<std::string> kv(it2, end2); |
2384 | if (kv.size() >= 2) { |
2385 | if (kv[0] == "backend" ) { |
2386 | if (kv[1] == "cudaMallocAsync" ) |
2387 | return CudaMallocAsync::allocator(); |
2388 | if (kv[1] == "native" ) |
2389 | return &Native::allocator; |
2390 | } |
2391 | } |
2392 | } |
2393 | } |
2394 | return &Native::allocator; |
2395 | } |
2396 | |
2397 | BackendStaticInitializer() { |
2398 | auto r = parseEnvForBackend(); |
2399 | allocator.store(r); |
2400 | } |
2401 | }; |
2402 | |
2403 | std::atomic<CUDAAllocator*> allocator{}; |
2404 | BackendStaticInitializer backend_static_initializer; |
2405 | |
2406 | } // namespace CUDACachingAllocator |
2407 | } // namespace cuda |
2408 | } // namespace c10 |
2409 | |