/**
 * The cache utilities in this file are adapted from PyTorch/XLA:
 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/cache.h
 */
5
6#pragma once
7
8#include <functional>
9#include <list>
10#include <memory>
11#include <mutex>
12#include <unordered_map>
13#include <utility>
14
15namespace torch {
16namespace lazy {
17
// Generic key/object cache with an LRU (least recently used) expiration
// policy. Objects of type T are stored as std::shared_ptr<T> and are accepted
// and returned as such by the cache API.
//
// Every public method acquires an internal mutex, so calls on one Cache
// instance are serialized.
//
// Internally, element_list_ holds the entries ordered from most to least
// recently used, and element_map_ indexes them by a pointer to the key stored
// inside the list node. std::list never relocates its elements, so those key
// pointers (and the stored list iterators) stay valid until the entry itself
// is erased.
//
// NOTE(review): TORCH_CHECK is used below, but this header does not include a
// definition for it; it relies on the includer having pulled it in first —
// confirm, or include the defining header explicitly.
template <
    typename K,
    typename T,
    typename H = std::hash<K>,
    typename E = std::equal_to<K>>
class Cache {
 public:
  using TypePtr = std::shared_ptr<T>;
  using Element = std::pair<K, TypePtr>;

  // Creates an empty cache holding at most `max_size` entries. With a
  // `max_size` of 0 nothing is retained: Add() still returns the object it
  // was given, but the entry is evicted immediately.
  explicit Cache(size_t max_size) : max_size_(max_size) {}

  // Adds an object to the cache, unless an entry with an equal key already
  // exists.
  //
  // Returns the cached object for `key`:
  // - If an equal key was already present, returns the existing object. That
  //   entry becomes the most recently used one, and `object` is discarded.
  // - Otherwise returns `object` itself.
  //
  // If the cache grows beyond the limit set during construction, the least
  // recently used entries are removed.
  TypePtr Add(K key, TypePtr object) {
    std::lock_guard<std::mutex> slock(lock_);
    // The map key is a pointer into the list node, so the node has to exist
    // before it can be indexed.
    element_list_.emplace_front(std::move(key), std::move(object));
    auto it = element_list_.begin();
    std::pair<typename ElementMap::iterator, bool> emplace_result;
    try {
      emplace_result = element_map_.emplace(&it->first, it);
    } catch (...) {
      // Keep the list and the map in sync if indexing the new node throws.
      element_list_.erase(it);
      throw;
    }
    if (!emplace_result.second) {
      // An equal key is already cached: drop the new node and refresh the
      // existing entry instead.
      element_list_.erase(it);
      auto existing = emplace_result.first->second;
      DoLRU(existing);
      return existing->second;
    }
    // Capture the result before evicting. When max_size_ == 0 the new node is
    // itself the eviction victim, and `emplace_result.first` would then refer
    // to an erased map entry.
    TypePtr result = it->second;
    while (element_list_.size() > max_size_) {
      element_map_.erase(&element_list_.back().first);
      element_list_.pop_back();
    }
    return result;
  }

  // Retrieves the existing object if it exists. If it does, its position in
  // the LRU list is moved to the head of the list.
  // Returns nullptr if no object with the specified key is found within the
  // cache.
  TypePtr Get(const K& key) {
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return nullptr;
    }
    DoLRU(it->second);
    return it->second->second;
  }

  // Returns the most recently added or accessed object. It is a TORCH_CHECK
  // failure to call this on an empty cache.
  TypePtr GetLatest() {
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(!element_list_.empty());
    return element_list_.front().second;
  }

  // Removes the entry for `key`. Returns true if an entry was removed and
  // false if the key was not cached.
  bool Erase(const K& key) {
    std::lock_guard<std::mutex> slock(lock_);
    auto it = element_map_.find(&key);
    if (it == element_map_.end()) {
      return false;
    }
    auto lit = it->second;
    // Drop the index first: its key pointer refers into the list node.
    element_map_.erase(it);
    element_list_.erase(lit);
    return true;
  }

  // Removes every entry from the cache.
  void Clear() {
    std::lock_guard<std::mutex> slock(lock_);
    element_map_.clear();
    element_list_.clear();
  }

  // Returns the number of cached entries. Also checks (via TORCH_CHECK) that
  // the list and the index agree on that count.
  int Numel() const {
    std::lock_guard<std::mutex> g(lock_);
    TORCH_CHECK(element_map_.size() == element_list_.size());
    return static_cast<int>(element_map_.size());
  }

 private:
  using ElementList = std::list<Element>;

  // Hashes a key through the pointer stored in element_map_.
  struct Hasher {
    size_t operator()(const K* key) const {
      return hasher(*key);
    }

    H hasher;
  };

  // Compares two keys through the pointers stored in element_map_.
  struct Equaler {
    bool operator()(const K* k1, const K* k2) const {
      return equaler(*k1, *k2);
    }

    E equaler;
  };

  using ElementMap = std::
      unordered_map<const K*, typename ElementList::iterator, Hasher, Equaler>;

  // Marks the entry at `it` as most recently used by moving it to the head of
  // the list. splice() relinks the node without invalidating iterators.
  void DoLRU(typename ElementList::iterator it) {
    element_list_.splice(element_list_.begin(), element_list_, it);
  }

  mutable std::mutex lock_;
  size_t max_size_ = 0;
  // Entries ordered from most recently used (front) to least (back).
  ElementList element_list_;
  // Key -> list position; keys point at the copies stored in element_list_.
  ElementMap element_map_;
};
127
128} // namespace lazy
129} // namespace torch
130