1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "primitive_cache.hpp" |
18 | #include "c_types_map.hpp" |
19 | #include "primitive.hpp" |
20 | #include "primitive_desc_iface.hpp" |
21 | #include "primitive_iface.hpp" |
22 | #include "rw_mutex.hpp" |
23 | #include "z_magic.hpp" |
24 | |
25 | #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE |
26 | #include "cpu/platform.hpp" |
27 | #else |
28 | #include <chrono> |
29 | #endif |
30 | |
31 | #include <algorithm> |
32 | #include <unordered_map> |
33 | |
34 | #ifdef _WIN32 |
35 | #include <windows.h> |
36 | #endif |
37 | |
38 | namespace dnnl { |
39 | namespace impl { |
40 | |
namespace {

// Returns a monotonically non-decreasing tick value used to order primitive
// cache entries for LRU eviction. When a CPU runtime is available the value
// comes from the platform layer; otherwise fall back to the steady clock's
// tick count.
size_t get_timestamp() {
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
    return cpu::platform::get_timestamp();
#else
    const auto now = std::chrono::steady_clock::now();
    return static_cast<size_t>(now.time_since_epoch().count());
#endif
}

} // namespace
52 | |
53 | primitive_cache_t &primitive_cache() { |
54 | #ifndef DNNL_DISABLE_PRIMITIVE_CACHE |
55 | static const int capacity |
56 | = getenv_int_user("PRIMITIVE_CACHE_CAPACITY" , 1024); |
57 | #else |
58 | static const int capacity = 0; |
59 | #endif |
60 | static lru_primitive_cache_t cache(capacity); |
61 | return cache; |
62 | } |
63 | |
64 | // Undocumented API, for testing only |
65 | status_t get_primitive_cache_size(int *size) { |
66 | if (size == nullptr) return dnnl::impl::status::invalid_arguments; |
67 | *size = 0; |
68 | #ifndef DNNL_DISABLE_PRIMITIVE_CACHE |
69 | *size = primitive_cache().get_size(); |
70 | #endif |
71 | return dnnl::impl::status::success; |
72 | } |
73 | |
74 | bool is_pd_in_cache(const primitive_desc_iface_t *pd_iface) { |
75 | const auto *pd = pd_iface->impl().get(); |
76 | const auto *engine = pd_iface->engine(); |
77 | primitive_hashing::key_t key(pd, engine); |
78 | return bool(primitive_cache().get_pd(key)); |
79 | } |
80 | |
81 | bool is_primitive_in_cache(const primitive_iface_t *p_iface) { |
82 | return is_pd_in_cache(p_iface->pd()); |
83 | } |
84 | |
85 | size_t set_primitive_cache_capacity_without_clearing(size_t capacity) { |
86 | size_t old_capacity = primitive_cache().get_capacity(); |
87 | static_cast<lru_primitive_cache_t &>((primitive_cache())).capacity_ |
88 | = capacity; |
89 | return old_capacity; |
90 | } |
91 | |
92 | status_t lru_primitive_cache_t::set_capacity(int capacity) { |
93 | utils::lock_write_t lock_w(rw_mutex()); |
94 | capacity_ = (size_t)capacity; |
95 | // Check if number of entries exceeds the new capacity |
96 | if (cache_mapper().size() > capacity_) { |
97 | // Evict excess entries |
98 | size_t n_excess_entries = cache_mapper().size() - capacity_; |
99 | evict(n_excess_entries); |
100 | } |
101 | return status::success; |
102 | } |
103 | |
104 | int lru_primitive_cache_t::get_capacity() const { |
105 | utils::lock_read_t lock_r(rw_mutex()); |
106 | return (int)capacity_; |
107 | } |
108 | |
109 | // For undocumented API |
110 | int lru_primitive_cache_t::get_size() const { |
111 | utils::lock_read_t lock_r(rw_mutex()); |
112 | return (int)cache_mapper().size(); |
113 | } |
114 | |
// Looks the key up in the cache; on a miss, inserts the provided value.
// Returns the cached entry on a hit and a default (invalid) value_t on a
// miss or when the cache is disabled — so an invalid return tells the caller
// the passed-in `value` was installed instead. Uses the classic
// double-checked pattern: an optimistic pass under the read lock, then a
// re-check under the write lock.
lru_primitive_cache_t::value_t lru_primitive_cache_t::get_or_add(
        const key_t &key, const value_t &value) {
    // 1. Section with shared access (read lock)
    lock_read();
    // Check if the cache is enabled (capacity_ == 0 means disabled).
    if (capacity_ == 0) {
        unlock_read();
        return value_t();
    }
    // Check if the requested entry is present in the cache (likely cache_hit).
    // Note: get() refreshes the entry's timestamp via an atomic store, which
    // is why it is safe to call under a shared lock.
    auto e = get(key);
    if (e.valid()) {
        unlock_read();
        return e;
    }

    unlock_read();

    // 2. Section with exclusive access (write lock).
    // In a multithreaded scenario, in the context of one thread the cache
    // may have changed by another thread between releasing the read lock and
    // acquiring the write lock (a.k.a. ABA problem), therefore additional
    // checks have to be performed for correctness.
    // Double check the capacity due to possible race condition: another
    // thread may have disabled the cache in the window between the locks.
    lock_write();
    if (capacity_ == 0) {
        unlock_write();
        return value_t();
    }

    // Double check if the requested entry is present in the cache (unlikely
    // cache_hit): another thread may have inserted it in the same window.
    e = get(key);
    if (!e.valid()) {
        // If the entry is missing in the cache then add it (cache_miss).
        // `e` stays invalid, signalling the insertion to the caller.
        add(key, value);
    }
    unlock_write();
    return e;
}
155 | |
156 | void lru_primitive_cache_t::add(const key_t &key, const value_t &value) { |
157 | // std::list::size() method has linear complexity. Check the primitive cache |
158 | // size using std::unordered_map::size(); |
159 | if (cache_mapper().size() == capacity_) { |
160 | // Evict the least recently used entry |
161 | evict(1); |
162 | } |
163 | |
164 | size_t timestamp = get_timestamp(); |
165 | |
166 | auto res = cache_mapper().emplace(std::piecewise_construct, |
167 | std::forward_as_tuple(key), |
168 | std::forward_as_tuple(value, timestamp)); |
169 | MAYBE_UNUSED(res); |
170 | assert(res.second); |
171 | } |
172 | |
173 | lru_primitive_cache_t::value_t lru_primitive_cache_t::get(const key_t &key) { |
174 | auto it = cache_mapper().find(key); |
175 | if (it == cache_mapper().end()) return value_t(); |
176 | |
177 | size_t timestamp = get_timestamp(); |
178 | it->second.timestamp_.store(timestamp); |
179 | // Return the entry |
180 | return it->second.value_; |
181 | } |
182 | |
183 | std::shared_ptr<primitive_desc_t> lru_primitive_cache_t::get_pd( |
184 | const key_t &key) { |
185 | lock_read(); |
186 | if (capacity_ == 0) { |
187 | unlock_read(); |
188 | return nullptr; |
189 | } |
190 | auto e = get(key); |
191 | unlock_read(); |
192 | |
193 | if (e.valid()) return e.get().primitive->pd(); |
194 | return nullptr; |
195 | } |
196 | |
197 | void lru_primitive_cache_t::remove_if_invalidated(const key_t &key) { |
198 | lock_write(); |
199 | |
200 | if (capacity_ == 0) { |
201 | unlock_write(); |
202 | return; |
203 | } |
204 | |
205 | auto it = cache_mapper().find(key); |
206 | if (it == cache_mapper().end()) { |
207 | // The entry has been already evicted at this point |
208 | unlock_write(); |
209 | return; |
210 | } |
211 | |
212 | const auto &value = it->second.value_; |
213 | if (value.get().primitive) { |
214 | // If the entry is not invalidated |
215 | unlock_write(); |
216 | return; |
217 | } |
218 | |
219 | // Remove the invalidated entry |
220 | cache_mapper().erase(it); |
221 | unlock_write(); |
222 | } |
223 | |
// Re-points the cached key's op_desc_/attr_ members at the (new) primitive
// descriptor's versions. The mutation happens in place on the map key —
// NOTE(review): since map keys are const, key_t's op_desc_/attr_ must be
// mutable (or this relies on pointer members being assignable); confirm in
// the key_t declaration. Runs under the write lock.
void lru_primitive_cache_t::update_entry(
        const key_t &key, const primitive_desc_t *pd) {
    lock_write();

    // Cache disabled: nothing to update.
    if (capacity_ == 0) {
        unlock_write();
        return;
    }

    auto it = cache_mapper().find(key);

    // There is nothing to do in two cases:
    // 1. The requested entry is not in the cache because it has been evicted
    //    by another thread
    // 2. After the requested entry had been evicted it was inserted again
    //    by another thread (detected via the thread-id mismatch below)
    if (it == cache_mapper().end()
            || it->first.thread_id() != key.thread_id()) {
        unlock_write();
        return;
    }

    const auto *op_desc = pd->op_desc();
    const auto *attr = pd->attr();

    // Update key in cache_mapper() so it refers to the descriptor data owned
    // by `pd` rather than whatever it pointed at before.
    it->first.op_desc_ = op_desc;
    it->first.attr_ = attr;
    unlock_write();
}
254 | |
255 | // Evicts n the least recently used entries |
256 | void lru_primitive_cache_t::evict(size_t n) { |
257 | using v_t = std::unordered_map<key_t, timed_entry_t>::value_type; |
258 | |
259 | if (n == capacity_) { |
260 | cache_mapper().clear(); |
261 | return; |
262 | } |
263 | |
264 | for (size_t e = 0; e < n; e++) { |
265 | // Find the smallest timestamp |
266 | // TODO: revisit the eviction algorithm due to O(n) complexity, E.g. |
267 | // maybe evict multiple entries at once. |
268 | auto it = std::min_element(cache_mapper().begin(), cache_mapper().end(), |
269 | [&](const v_t &left, const v_t &right) { |
270 | // By default, load() and operator T use sequentially |
271 | // consistent memory ordering, which enforces writing the |
272 | // timestamps into registers in the same exact order they |
273 | // are read from the CPU cache line. Since eviction is |
274 | // performed under a write lock, this order is not |
275 | // important, therefore we can safely use the weakest memory |
276 | // ordering (relaxed). This brings about a few microseconds |
277 | // performance improvement for default primitive cache |
278 | // capacity. |
279 | return left.second.timestamp_.load( |
280 | std::memory_order_relaxed) |
281 | < right.second.timestamp_.load( |
282 | std::memory_order_relaxed); |
283 | }); |
284 | auto res = cache_mapper().erase(it->first); |
285 | MAYBE_UNUSED(res); |
286 | assert(res); |
287 | } |
288 | } |
289 | |
// Destructor. The complication below exists only for Windows builds with a
// SYCL/OpenCL GPU runtime: if the whole process is terminating, GPU runtime
// DLLs may already be unloaded, so destroying cached primitives that
// reference them is unsafe. In that case the cache content is deliberately
// leaked (cache_mapper_.release() — NOTE(review): cache_mapper_ appears to
// be a smart-pointer-like holder whose release() relinquishes ownership
// without destroying; confirm in the header). Everywhere else the content is
// destroyed normally via cache_mapper_.reset().
lru_primitive_cache_t::~lru_primitive_cache_t() {
    if (cache_mapper().empty()) return;

#if defined(_WIN32) \
        && (defined(DNNL_WITH_SYCL) || DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL)
    // The ntdll.dll library is located in system32 therefore setting additional
    // environment is not required.
    HMODULE handle = LoadLibraryExA(
            "ntdll.dll", nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32);
    if (!handle) {
        // Cannot even query the shutdown state: err on the safe side and
        // leak the cache content.
        cache_mapper_.release();
        return;
    }

    // RtlDllShutdownInProgress returns TRUE if the whole process terminates and
    // FALSE if DLL is being unloaded dynamically or if it's called from an
    // executable.
    auto f = reinterpret_cast<BOOLEAN (*)(void)>(
            GetProcAddress(handle, "RtlDllShutdownInProgress"));
    if (!f) {
        // Symbol lookup failed: again err on the safe side and leak.
        auto ret = FreeLibrary(handle);
        assert(ret);
        MAYBE_UNUSED(ret);
        cache_mapper_.release();
        return;
    }

    bool is_process_termination_in_progress = f();

    auto ret = FreeLibrary(handle);
    assert(ret);
    MAYBE_UNUSED(ret);

    if (is_process_termination_in_progress) {
        // The whole process is being terminated hence destroying content of
        // the primitive cache cannot be done safely. However we can check
        // all entries and remove those that are not affected e.g. native CPU.
        for (auto it = cache_mapper().begin(); it != cache_mapper().end();) {
            const auto &engine_id = it->first.engine_id_;
            if (engine_id.kind() == engine_kind::cpu
                    && is_native_runtime(engine_id.runtime_kind())) {
                it = cache_mapper().erase(it);
            } else {
                ++it;
            }
        }
        // Leak whatever remains (GPU/SYCL-backed entries).
        cache_mapper_.release();
    } else {
        // Three scenarios possible:
        // 1. oneDNN is being dynamically unloaded
        // 2. Another dynamic library that contains statically linked oneDNN is
        //    dynamically unloaded
        // 3. oneDNN is statically linked in an executable which is done and now
        //    the process terminates
        // In all these scenarios content of the primitive cache can be safely
        // destroyed.
        cache_mapper_.reset();
    }
#else
    // Always destroy the content of the primitive cache for non-Windows OSes,
    // and non-sycl and non-ocl runtimes because there is no a problem with
    // library unloading order in such cases.
    cache_mapper_.reset();
#endif
}
355 | |
356 | } // namespace impl |
357 | } // namespace dnnl |
358 | |
359 | // API |
360 | dnnl::impl::status_t dnnl_get_primitive_cache_capacity(int *capacity) { |
361 | if (capacity == nullptr) return dnnl::impl::status::invalid_arguments; |
362 | *capacity = 0; |
363 | #ifndef DNNL_DISABLE_PRIMITIVE_CACHE |
364 | *capacity = dnnl::impl::primitive_cache().get_capacity(); |
365 | #endif |
366 | return dnnl::impl::status::success; |
367 | } |
368 | |
369 | dnnl::impl::status_t dnnl_set_primitive_cache_capacity(int capacity) { |
370 | if (capacity < 0) return dnnl::impl::status::invalid_arguments; |
371 | #ifndef DNNL_DISABLE_PRIMITIVE_CACHE |
372 | return dnnl::impl::primitive_cache().set_capacity(capacity); |
373 | #endif |
374 | return dnnl::impl::status::success; |
375 | } |
376 | |