/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "primitive_cache.hpp"
#include "c_types_map.hpp"
#include "primitive.hpp"
#include "primitive_desc_iface.hpp"
#include "primitive_iface.hpp"
#include "rw_mutex.hpp"
#include "z_magic.hpp"

#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
#include "cpu/platform.hpp"
#else
#include <chrono>
#endif

#include <algorithm>
#include <unordered_map>

#ifdef _WIN32
#include <windows.h>
#endif

namespace dnnl {
namespace impl {

namespace {

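// Only the relative order of the values returned by get_timestamp() matters:
// the timestamps are used solely to pick the least recently used entry in
// evict(), so any monotonically non-decreasing counter suffices here.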
size_t get_timestamp() {
#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
    return cpu::platform::get_timestamp();
#else
    return std::chrono::steady_clock::now().time_since_epoch().count();
#endif
}

} // namespace

primitive_cache_t &primitive_cache() {
#ifndef DNNL_DISABLE_PRIMITIVE_CACHE
    static const int capacity
            = getenv_int_user("PRIMITIVE_CACHE_CAPACITY", 1024);
#else
    static const int capacity = 0;
#endif
    static lru_primitive_cache_t cache(capacity);
    return cache;
}
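
// The default capacity above can be overridden at run time. Assuming that
// getenv_int_user("PRIMITIVE_CACHE_CAPACITY", ...) resolves to the documented
// ONEDNN_PRIMITIVE_CACHE_CAPACITY environment variable, a hypothetical
// invocation would be:
//
//     ONEDNN_PRIMITIVE_CACHE_CAPACITY=2048 ./app
//
// Setting the capacity to 0 disables the cache entirely (see the
// capacity_ == 0 checks below).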

// Undocumented API, for testing only
status_t get_primitive_cache_size(int *size) {
    if (size == nullptr) return dnnl::impl::status::invalid_arguments;
    *size = 0;
#ifndef DNNL_DISABLE_PRIMITIVE_CACHE
    *size = primitive_cache().get_size();
#endif
    return dnnl::impl::status::success;
}

bool is_pd_in_cache(const primitive_desc_iface_t *pd_iface) {
    const auto *pd = pd_iface->impl().get();
    const auto *engine = pd_iface->engine();
    primitive_hashing::key_t key(pd, engine);
    return bool(primitive_cache().get_pd(key));
}

bool is_primitive_in_cache(const primitive_iface_t *p_iface) {
    return is_pd_in_cache(p_iface->pd());
}

size_t set_primitive_cache_capacity_without_clearing(size_t capacity) {
    size_t old_capacity = primitive_cache().get_capacity();
    static_cast<lru_primitive_cache_t &>(primitive_cache()).capacity_
            = capacity;
    return old_capacity;
}

status_t lru_primitive_cache_t::set_capacity(int capacity) {
    utils::lock_write_t lock_w(rw_mutex());
    capacity_ = (size_t)capacity;
    // Check if the number of entries exceeds the new capacity
    if (cache_mapper().size() > capacity_) {
        // Evict the excess entries
        size_t n_excess_entries = cache_mapper().size() - capacity_;
        evict(n_excess_entries);
    }
    return status::success;
}

int lru_primitive_cache_t::get_capacity() const {
    utils::lock_read_t lock_r(rw_mutex());
    return (int)capacity_;
}

// For the undocumented API
int lru_primitive_cache_t::get_size() const {
    utils::lock_read_t lock_r(rw_mutex());
    return (int)cache_mapper().size();
}

lru_primitive_cache_t::value_t lru_primitive_cache_t::get_or_add(
        const key_t &key, const value_t &value) {
    // 1. Section with shared access (read lock)
    lock_read();
    // Check if the cache is enabled
    if (capacity_ == 0) {
        unlock_read();
        return value_t();
    }
    // Check if the requested entry is present in the cache (likely cache hit)
    auto e = get(key);
    if (e.valid()) {
        unlock_read();
        return e;
    }

    unlock_read();

    // 2. Section with exclusive access (write lock).
    // In a multithreaded scenario the cache may have been modified by another
    // thread between releasing the read lock and acquiring the write lock (a
    // classic check-then-act race), therefore the checks have to be repeated
    // for correctness.
    // Double check the capacity due to a possible race condition
    lock_write();
    if (capacity_ == 0) {
        unlock_write();
        return value_t();
    }

    // Double check if the requested entry is present in the cache (unlikely
    // cache hit)
    e = get(key);
    if (!e.valid()) {
        // The entry is missing in the cache, so add it (cache miss)
        add(key, value);
    }
    unlock_write();
    return e;
}
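
// A hypothetical caller-side sketch (illustrative names, not the actual
// oneDNN call sites): publish a freshly built candidate via get_or_add() so
// that concurrent threads converge on a single cached instance:
//
//     primitive_hashing::key_t key(pd, engine);
//     auto e = primitive_cache().get_or_add(key, candidate);
//     if (e.valid()) {
//         // Cache hit: reuse e.get().primitive.
//     } else {
//         // Cache miss: `candidate` was added (unless the cache is
//         // disabled, i.e. capacity is 0).
//     }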

void lru_primitive_cache_t::add(const key_t &key, const value_t &value) {
    // Note: std::list::size() may have linear complexity (it is only required
    // to be O(1) starting with C++11), so the primitive cache size is checked
    // via std::unordered_map::size(), which is constant time.
    if (cache_mapper().size() == capacity_) {
        // Evict the least recently used entry
        evict(1);
    }

    size_t timestamp = get_timestamp();

    auto res = cache_mapper().emplace(std::piecewise_construct,
            std::forward_as_tuple(key),
            std::forward_as_tuple(value, timestamp));
    MAYBE_UNUSED(res);
    assert(res.second);
}
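
// Why emplace() with std::piecewise_construct above: timed_entry_t stores its
// timestamp in an std::atomic (see the load()/store() calls elsewhere in this
// file), and atomics are neither copyable nor movable, so the entry has to be
// constructed in place inside the map node rather than built and then moved.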

lru_primitive_cache_t::value_t lru_primitive_cache_t::get(const key_t &key) {
    auto it = cache_mapper().find(key);
    if (it == cache_mapper().end()) return value_t();

    // Bump the timestamp to mark the entry as most recently used. The
    // timestamp_ member is atomic, so concurrent readers holding only the
    // read lock can refresh it safely.
    size_t timestamp = get_timestamp();
    it->second.timestamp_.store(timestamp);
    // Return the entry
    return it->second.value_;
}

std::shared_ptr<primitive_desc_t> lru_primitive_cache_t::get_pd(
        const key_t &key) {
    lock_read();
    if (capacity_ == 0) {
        unlock_read();
        return nullptr;
    }
    auto e = get(key);
    unlock_read();

    if (e.valid()) return e.get().primitive->pd();
    return nullptr;
}

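// An entry is treated as invalidated when its value no longer holds a
// primitive (presumably because the creation of the primitive failed after
// the entry had been published); this function purges such entries.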
void lru_primitive_cache_t::remove_if_invalidated(const key_t &key) {
    lock_write();

    if (capacity_ == 0) {
        unlock_write();
        return;
    }

    auto it = cache_mapper().find(key);
    if (it == cache_mapper().end()) {
        // The entry has already been evicted at this point
        unlock_write();
        return;
    }

    const auto &value = it->second.value_;
    if (value.get().primitive) {
        // The entry is not invalidated
        unlock_write();
        return;
    }

    // Remove the invalidated entry
    cache_mapper().erase(it);
    unlock_write();
}

void lru_primitive_cache_t::update_entry(
        const key_t &key, const primitive_desc_t *pd) {
    lock_write();

    if (capacity_ == 0) {
        unlock_write();
        return;
    }

    auto it = cache_mapper().find(key);

    // There is nothing to do in two cases:
    // 1. The requested entry is not in the cache because it has been evicted
    //    by another thread
    // 2. After the requested entry had been evicted it was inserted again
    //    by another thread
    if (it == cache_mapper().end()
            || it->first.thread_id() != key.thread_id()) {
        unlock_write();
        return;
    }

    const auto *op_desc = pd->op_desc();
    const auto *attr = pd->attr();

    // Update the key in cache_mapper()
    it->first.op_desc_ = op_desc;
    it->first.attr_ = attr;
    unlock_write();
}

// Evicts the n least recently used entries
void lru_primitive_cache_t::evict(size_t n) {
    using v_t = std::unordered_map<key_t, timed_entry_t>::value_type;

    if (n == capacity_) {
        cache_mapper().clear();
        return;
    }

    for (size_t e = 0; e < n; e++) {
        // Find the entry with the smallest timestamp.
        // TODO: revisit the eviction algorithm due to the O(n) complexity of
        // each step, e.g. maybe evict multiple entries at once.
        auto it = std::min_element(cache_mapper().begin(), cache_mapper().end(),
                [&](const v_t &left, const v_t &right) {
                    // By default, load() and operator T() use sequentially
                    // consistent memory ordering, which imposes unnecessary
                    // ordering constraints on the timestamp reads. Since
                    // eviction is performed under a write lock, no concurrent
                    // timestamp updates can occur, therefore the weakest
                    // memory ordering (relaxed) can safely be used. This
                    // brings a few microseconds of performance improvement
                    // for the default primitive cache capacity.
                    return left.second.timestamp_.load(
                                   std::memory_order_relaxed)
                            < right.second.timestamp_.load(
                                    std::memory_order_relaxed);
                });
        auto res = cache_mapper().erase(it->first);
        MAYBE_UNUSED(res);
        assert(res);
    }
}
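
// The TODO above could hypothetically be addressed by a single-pass bulk
// eviction instead of n linear scans. A minimal sketch, assuming the write
// lock is held as in evict() (ties around the cutoff may evict slightly more
// than n entries, and <vector> would additionally be required):
//
//     std::vector<size_t> ts;
//     ts.reserve(cache_mapper().size());
//     for (const auto &kv : cache_mapper())
//         ts.push_back(kv.second.timestamp_.load(std::memory_order_relaxed));
//     std::nth_element(ts.begin(), ts.begin() + (n - 1), ts.end());
//     const size_t cutoff = ts[n - 1];
//     for (auto it = cache_mapper().begin(); it != cache_mapper().end();) {
//         if (it->second.timestamp_.load(std::memory_order_relaxed) <= cutoff)
//             it = cache_mapper().erase(it);
//         else
//             ++it;
//     }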

lru_primitive_cache_t::~lru_primitive_cache_t() {
    if (cache_mapper().empty()) return;

#if defined(_WIN32) \
        && (defined(DNNL_WITH_SYCL) || DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL)
    // The ntdll.dll library resides in system32, therefore no additional
    // search path setup is required.
    HMODULE handle = LoadLibraryExA(
            "ntdll.dll", nullptr, LOAD_LIBRARY_SEARCH_SYSTEM32);
    if (!handle) {
        cache_mapper_.release();
        return;
    }

    // RtlDllShutdownInProgress returns TRUE if the whole process is
    // terminating, and FALSE if a DLL is being unloaded dynamically or if
    // the function is called from an executable.
    auto f = reinterpret_cast<BOOLEAN (*)(void)>(
            GetProcAddress(handle, "RtlDllShutdownInProgress"));
    if (!f) {
        auto ret = FreeLibrary(handle);
        assert(ret);
        MAYBE_UNUSED(ret);
        cache_mapper_.release();
        return;
    }

    bool is_process_termination_in_progress = f();

    auto ret = FreeLibrary(handle);
    assert(ret);
    MAYBE_UNUSED(ret);

    if (is_process_termination_in_progress) {
        // The whole process is being terminated, hence the contents of the
        // primitive cache cannot be destroyed safely. However, we can walk
        // the entries and remove those that are not affected, e.g. the ones
        // that belong to the native CPU runtime.
        for (auto it = cache_mapper().begin(); it != cache_mapper().end();) {
            const auto &engine_id = it->first.engine_id_;
            if (engine_id.kind() == engine_kind::cpu
                    && is_native_runtime(engine_id.runtime_kind())) {
                it = cache_mapper().erase(it);
            } else {
                ++it;
            }
        }
        cache_mapper_.release();
    } else {
        // Three scenarios are possible:
        // 1. oneDNN is being dynamically unloaded
        // 2. Another dynamic library that contains statically linked oneDNN
        //    is being dynamically unloaded
        // 3. oneDNN is statically linked into an executable that has finished
        //    its work and the process is now terminating
        // In all these scenarios the contents of the primitive cache can be
        // safely destroyed.
        cache_mapper_.reset();
    }
#else
    // Always destroy the contents of the primitive cache for non-Windows OSes
    // and for non-SYCL and non-OCL runtimes, because there is no library
    // unloading order problem in such cases.
    cache_mapper_.reset();
#endif
}
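
// A note on cache_mapper_.release() vs cache_mapper_.reset() above: assuming
// cache_mapper_ has unique_ptr-like semantics, release() intentionally leaks
// the map (destroying the entries could call into a runtime library that has
// already been unloaded), whereas reset() actually destroys it.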

} // namespace impl
} // namespace dnnl

// API
dnnl::impl::status_t dnnl_get_primitive_cache_capacity(int *capacity) {
    if (capacity == nullptr) return dnnl::impl::status::invalid_arguments;
    *capacity = 0;
#ifndef DNNL_DISABLE_PRIMITIVE_CACHE
    *capacity = dnnl::impl::primitive_cache().get_capacity();
#endif
    return dnnl::impl::status::success;
}

dnnl::impl::status_t dnnl_set_primitive_cache_capacity(int capacity) {
    if (capacity < 0) return dnnl::impl::status::invalid_arguments;
#ifndef DNNL_DISABLE_PRIMITIVE_CACHE
    return dnnl::impl::primitive_cache().set_capacity(capacity);
#endif
    return dnnl::impl::status::success;
}
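
// A hypothetical user-side snippet exercising the two entry points above
// (error handling elided):
//
//     int capacity = 0;
//     dnnl_get_primitive_cache_capacity(&capacity);
//     dnnl_set_primitive_cache_capacity(2 * capacity);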