1 | /** |
2 | * Cache utils in this file is adapted from PyTorch/XLA |
3 | * https://github.com/pytorch/xla/blob/master/third_party/xla_client/cache.h |
4 | */ |
5 | |
6 | #pragma once |
7 | |
8 | #include <functional> |
9 | #include <list> |
10 | #include <memory> |
11 | #include <mutex> |
12 | #include <unordered_map> |
13 | #include <utility> |
14 | |
15 | namespace torch { |
16 | namespace lazy { |
17 | |
// Generic key and object cache with LRU expiration policy. The objects of type
// T will be stored as std::shared_ptr<T> and taken and returned as such, by the
// cache API.
//
// Thread-safe: every public member function serializes on lock_.
// Storage layout: each (key, object) pair lives exactly once in element_list_
// (ordered most- to least-recently used); element_map_ indexes the list by
// pointer-to-key, with Hasher/Equaler dereferencing the pointers, so keys are
// never duplicated.
template <
    typename K,
    typename T,
    typename H = std::hash<K>,
    typename E = std::equal_to<K>>
class Cache {
 public:
  using TypePtr = std::shared_ptr<T>;
  using Element = std::pair<K, TypePtr>;

  explicit Cache(size_t max_size) : max_size_(max_size) {}

  // Adds an object to the cache, unless it already exists. If the cache grows
  // beyond the limit set during construction, the oldest used object will be
  // removed from the cache.
  // Returns the cached object for `key`: the newly added one, or, when the
  // key was already present, the pre-existing one (the passed `object` is
  // then discarded).
  TypePtr Add(K key, TypePtr object) {
    if (max_size_ == 0) {
      // Degenerate zero-capacity cache: nothing can be retained. Bail out
      // early — otherwise the eviction branch below would pop the element we
      // just inserted and the final return would dereference a dangling
      // iterator. (max_size_ is immutable, so reading it unlocked is safe.)
      return object;
    }
    std::lock_guard<std::mutex> slock(lock_);
    element_list_.emplace_front(Element(std::move(key), std::move(object)));
    auto it = element_list_.begin();
    auto emplace_result = element_map_.emplace(&it->first, it);
    if (!emplace_result.second) {
      // Key already present: drop the speculative front element and promote
      // the existing entry to the head of the LRU list instead.
      element_list_.erase(it);
      DoLRU(emplace_result.first->second);
    } else if (element_list_.size() > max_size_) {
      // Evict the least recently used element (the list tail). max_size_ >= 1
      // here, so the tail can never be the element just added.
      Element* last = &element_list_.back();
      element_map_.erase(&last->first);
      element_list_.pop_back();
    }
    return emplace_result.first->second->second;
  }

  // Retrieves the existing object if it exists. If it does, its position in
  // the LRU list gets moved to the head of the list.
  // Returns nullptr if no object with the specified key is found within the
  // cache.
  TypePtr Get(const K& key) {
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return nullptr;
    }
    DoLRU(it->second);
    return it->second->second;
  }

  // Returns the most recently used object. Asserts (TORCH_CHECK) that the
  // cache is not empty; does not update LRU order.
  TypePtr GetLatest() {
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(!element_list_.empty());
    return element_list_.front().second;
  }

  // Removes the entry for `key`. Returns true if an entry was removed, false
  // if the key was not present.
  bool Erase(const K& key) {
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return false;
    }
    auto lit = it->second;
    // Erase the map entry first: it holds a pointer into the list node that
    // the subsequent list erase destroys.
    element_map_.erase(it);
    element_list_.erase(lit);
    return true;
  }

  // Removes every entry from the cache.
  void Clear() {
    std::lock_guard<std::mutex> slock(lock_);
    element_map_.clear();
    element_list_.clear();
  }

  // Returns the number of cached entries, asserting (TORCH_CHECK) that the
  // map and list views agree.
  int Numel() const {
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(element_map_.size() == element_list_.size());
    // Explicit cast: the public interface returns int, internal sizes are
    // size_t.
    return static_cast<int>(element_map_.size());
  }

 private:
  using ElementList = std::list<Element>;

  // Hashes a key through its pointer so element_map_ can share the key
  // storage owned by element_list_.
  struct Hasher {
    size_t operator()(const K* key) const {
      return hasher(*key);
    }

    H hasher;
  };

  // Compares keys through their pointers (see Hasher).
  struct Equaler {
    bool operator()(const K* k1, const K* k2) const {
      return equaler(*k1, *k2);
    }

    E equaler;
  };

  using ElementMap = std::
      unordered_map<const K*, typename ElementList::iterator, Hasher, Equaler>;

  // Moves the given list element to the front (most recently used position).
  // std::list::splice relinks nodes, so iterators stored in element_map_
  // remain valid.
  void DoLRU(typename ElementList::iterator it) {
    element_list_.splice(element_list_.begin(), element_list_, it);
  }

  mutable std::mutex lock_;
  size_t max_size_ = 0;
  ElementList element_list_;
  ElementMap element_map_;
};
127 | |
128 | } // namespace lazy |
129 | } // namespace torch |
130 | |