1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * This source code is licensed under the BSD-style license found in the |
5 | * LICENSE file in the root directory of this source tree. |
6 | */ |
7 | #pragma once |
#include <condition_variable>
#include <exception>
#include <future>
#include <map>
11 | |
12 | #if __cplusplus >= 201402L && !defined(__APPLE__) |
// For C++14 and later, use std::shared_timed_mutex so cache lookups can share the lock.
// Some macOS C++14 toolchains don't support shared_timed_mutex, so Apple builds fall back to std::mutex.
15 | #define FBGEMM_USE_SHARED_TIMED_MUTEX |
16 | #endif |
17 | |
18 | #ifdef FBGEMM_USE_SHARED_TIMED_MUTEX |
19 | #include <shared_mutex> |
20 | #else |
21 | #include <mutex> |
22 | #endif |
23 | |
24 | #ifdef FBCODE_CAFFE2 |
25 | #include <folly/container/F14Map.h> |
26 | #endif |
27 | |
28 | namespace fbgemm { |
29 | |
/**
 * @brief Thread safe cache for microkernels, ensures single creation per key.
 *
 * The first caller for a key publishes a placeholder future, releases the
 * cache lock, and runs the generator. Concurrent callers for the same key wait
 * on that future *without* holding the cache lock, so generating one key never
 * blocks lookups or insertions of other keys.
 *
 * If the generator throws, the exception is propagated to the calling thread
 * and to any threads already waiting on that key, and the key is removed from
 * the cache so that a later call can retry the generation.
 *
 * @tparam KEY Type of unique key (typically a tuple)
 * @tparam VALUE Type of the microkernel function (Typically a function pointer)
 * @tparam THREAD_LOCAL use thread local and avoid locking (default false)
 */
template <typename KEY, typename VALUE, bool THREAD_LOCAL = false>
class CodeCache {
 private:
#ifdef FBCODE_CAFFE2
  folly::F14FastMap<KEY, std::shared_future<VALUE>> values_;
#else
  std::map<KEY, std::shared_future<VALUE>> values_;
#endif

#ifdef FBGEMM_USE_SHARED_TIMED_MUTEX
  std::shared_timed_mutex mutex_;
  using UniqueLock = std::unique_lock<std::shared_timed_mutex>;
#else
  std::mutex mutex_;
  using UniqueLock = std::unique_lock<std::mutex>;
#endif

 public:
  CodeCache(const CodeCache&) = delete;
  CodeCache& operator=(const CodeCache&) = delete;

  CodeCache() = default;

  /**
   * @brief Returns the cached value for @p key, generating it on first use.
   * @param key Lookup key.
   * @param generatorFunction Nullary callable returning VALUE. It is invoked
   *        by the thread that inserts the key, outside the cache lock.
   * @return The cached (or freshly generated) value.
   * @throws Whatever generatorFunction throws; the failed entry is dropped so
   *         the key can be generated again by a later call.
   */
  template <typename GENFUNC>
  VALUE getOrCreate(const KEY& key, GENFUNC generatorFunction) {
#ifdef FBGEMM_USE_SHARED_TIMED_MUTEX
    // Fast path: probe under a shared (reader) lock first.
    {
      std::shared_lock<std::shared_timed_mutex> sharedLock(mutex_);
      auto it = values_.find(key);
      if (it != values_.end()) {
        // Copy the future and release the lock before waiting: the value may
        // still be under construction by another thread.
        std::shared_future<VALUE> existing = it->second;
        sharedLock.unlock();
        return existing.get();
      }
    }
#endif

    UniqueLock uniqueLock(mutex_);
    // (Re)check under the exclusive lock: with a shared mutex, another thread
    // may have inserted the key between releasing the shared lock and
    // acquiring this one.
    auto it = values_.find(key);
    if (it != values_.end()) {
      std::shared_future<VALUE> existing = it->second;
      uniqueLock.unlock();
      return existing.get();
    }

    // Publish a placeholder so concurrent callers wait instead of regenerating.
    std::promise<VALUE> returnPromise;
    values_.emplace(key, returnPromise.get_future().share());
    uniqueLock.unlock();

    // The value (code) generation is not happening under a lock.
    try {
      VALUE val = generatorFunction();
      returnPromise.set_value(val);
      return val;
    } catch (...) {
      // Give waiting threads the real error instead of broken_promise...
      returnPromise.set_exception(std::current_exception());
      // ...and drop the failed entry so the key is not poisoned forever.
      // Only this thread inserted the entry, so it is still ours to erase.
      uniqueLock.lock();
      values_.erase(key);
      throw;
    }
  }
};
95 | |
96 | // This class must be used as a static variable. |
97 | template <typename KEY, typename VALUE> |
98 | class CodeCache<KEY, VALUE, /*THREAD_LOCAL=*/true> { |
99 | private: |
100 | #ifdef FBCODE_CAFFE2 |
101 | static folly::F14FastMap<KEY, VALUE>& getValues_() { |
102 | static thread_local folly::F14FastMap<KEY, VALUE> |
103 | values_; /* library-local */ |
104 | return values_; |
105 | } |
106 | #else |
107 | static std::map<KEY, VALUE>& getValues_() { |
108 | static thread_local std::map<KEY, VALUE> values_; |
109 | return values_; |
110 | } |
111 | #endif |
112 | |
113 | public: |
114 | CodeCache(const CodeCache&) = delete; |
115 | CodeCache& operator=(const CodeCache&) = delete; |
116 | |
117 | CodeCache() {} |
118 | |
119 | template <typename GENFUNC> |
120 | VALUE getOrCreate(const KEY& key, GENFUNC generatorFunction) { |
121 | // Check for existence of the key |
122 | auto it = getValues_().find(key); |
123 | if (it != getValues_().end()) { |
124 | return it->second; |
125 | } else { |
126 | VALUE val = generatorFunction(); |
127 | getValues_()[key] = val; |
128 | return val; |
129 | } |
130 | } |
131 | }; |
132 | |
133 | } // namespace fbgemm |
134 | |