1 | #include <c10/mobile/CPUCachingAllocator.h> |
---|---|
2 | |
3 | #include <c10/core/impl/alloc_cpu.h> |
4 | |
5 | namespace c10 { |
6 | |
7 | namespace { |
8 | thread_local CPUCachingAllocator* caching_allocator_ptr{nullptr}; |
9 | } // namespace |
10 | |
11 | std::mutex CPUCachingAllocator::mutex_; |
12 | ska::flat_hash_map<void*, size_t> CPUCachingAllocator::allocation_map_; |
13 | |
14 | inline void* CPUCachingAllocator::allocate_and_cache(const size_t bytes) { |
15 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
16 | void* ptr; |
17 | try { |
18 | ptr = c10::alloc_cpu(bytes); |
19 | } catch (c10::Error& e) { |
20 | // If allocation fails, try freeing cached available blocks. |
21 | // For now free all available cached blocks. |
22 | free_cached(); |
23 | // Furthermore to consider: If we ever come here running out of memory |
24 | // perhaps it is best to disable caching, since this is likely to happen |
25 | // again. |
26 | // Try again. |
27 | ptr = c10::alloc_cpu(bytes); |
28 | } |
29 | allocation_map_[ptr] = bytes; |
30 | return ptr; |
31 | } |
32 | |
33 | void* CPUCachingAllocator::allocate(const size_t bytes) { |
34 | std::lock_guard<std::mutex> guard(mutex_); |
35 | const auto& it = available_map_.find(bytes); |
36 | if (it == available_map_.end() || it->second.empty()) { |
37 | return allocate_and_cache(bytes); |
38 | } |
39 | return it->second.pop_back_val(); |
40 | } |
41 | |
42 | void CPUCachingAllocator::free(void* ptr) { |
43 | // NB: since we are not really freeing the memory |
44 | // the cases such as quantization code freeing original weights |
45 | // on mobile, will not quite work, as we likely will hold |
46 | // onto that memory. |
47 | // NB: We can also enable max memory cached for better memory |
48 | // management such that free will actually free the memory if |
49 | // we are nearing or above the watermark. |
50 | std::lock_guard<std::mutex> guard(mutex_); |
51 | // If this allocation was done before caching allocator was enabled |
52 | // then free regularly |
53 | const auto& it = allocation_map_.find(ptr); |
54 | if (it == allocation_map_.end()) { |
55 | c10::free_cpu(ptr); |
56 | return; |
57 | } |
58 | const size_t alloc_size = it->second; |
59 | available_map_[alloc_size].push_back(ptr); |
60 | } |
61 | |
62 | void CPUCachingAllocator::record_free(void* ptr) { |
63 | // This function captures the case when the allocated memory |
64 | // is being freed outside the scope of this allocator. |
65 | // At the moment only way to capture this is to have the allocator, |
66 | // that uses this CachingAllocator as the backing allocator, |
67 | // call this function explicitly upon freeing memory while |
68 | // outside the scope of caching allocator. |
69 | // If the memory is freed in some other way, then we will likely |
70 | // have undefined behavior or page fault. But this can be |
71 | // the case without caching allocator as well. |
72 | std::lock_guard<std::mutex> guard(mutex_); |
73 | const auto& it = allocation_map_.find(ptr); |
74 | if (it != allocation_map_.end()) { |
75 | allocation_map_.erase(it); |
76 | } |
77 | } |
78 | |
79 | void CPUCachingAllocator::free_cached() { |
80 | for (const auto& it : available_map_) { |
81 | for (const auto ptr : it.second) { |
82 | c10::free_cpu(ptr); |
83 | // When cached memory is return to OS, it must be removed |
84 | // from allocation_map. |
85 | allocation_map_.erase(ptr); |
86 | } |
87 | } |
88 | available_map_.clear(); |
89 | } |
90 | |
91 | CPUCachingAllocator::~CPUCachingAllocator() { |
92 | free_cached(); |
93 | } |
94 | |
95 | CPUCachingAllocator* GetThreadLocalCachingAllocator() { |
96 | return caching_allocator_ptr; |
97 | } |
98 | |
99 | WithCPUCachingAllocatorGuard::WithCPUCachingAllocatorGuard( |
100 | CPUCachingAllocator* allocator) |
101 | : prev_caching_allocator_ptr_(GetThreadLocalCachingAllocator()) { |
102 | caching_allocator_ptr = allocator; |
103 | } |
104 | |
105 | WithCPUCachingAllocatorGuard::~WithCPUCachingAllocatorGuard() { |
106 | caching_allocator_ptr = prev_caching_allocator_ptr_; |
107 | } |
108 | |
109 | } // namespace c10 |
110 |