1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_PRIMITIVE_CACHE_HPP
18#define COMMON_PRIMITIVE_CACHE_HPP
19
#include <atomic>
#include <future>
#include <memory>
#include <thread>
#include <unordered_map>

#include "c_types_map.hpp"
#include "oneapi/dnnl/dnnl.h"
#include "primitive_hashing.hpp"
#include "rw_mutex.hpp"
#include "type_helpers.hpp"
30
31namespace dnnl {
32namespace impl {
33
34struct primitive_t;
35struct primitive_cache_t : public c_compatible {
36 struct cache_value_t {
37 std::shared_ptr<primitive_t> primitive;
38 status_t status;
39 };
40 using key_t = primitive_hashing::key_t;
41 using value_t = std::shared_future<cache_value_t>;
42
43 virtual ~primitive_cache_t() = default;
44
45 virtual status_t set_capacity(int capacity) = 0;
46 virtual int get_capacity() const = 0;
47
48 virtual value_t get_or_add(const key_t &key, const value_t &value) = 0;
49 virtual void remove_if_invalidated(const key_t &key) = 0;
50 virtual void update_entry(const key_t &key, const primitive_desc_t *pd) = 0;
51
52 virtual int get_size() const = 0;
53
54 virtual std::shared_ptr<primitive_desc_t> get_pd(const key_t &key) = 0;
55
56protected:
57 static utils::rw_mutex_t &rw_mutex() {
58 static utils::rw_mutex_t mutex;
59 return mutex;
60 }
61
62 void lock_read() { rw_mutex().lock_read(); }
63 void lock_write() { rw_mutex().lock_write(); }
64 void unlock_read() { rw_mutex().unlock_read(); }
65 void unlock_write() { rw_mutex().unlock_write(); }
66};
67
68// The cache uses LRU replacement policy
69struct lru_primitive_cache_t : public primitive_cache_t {
70 lru_primitive_cache_t(int capacity) : capacity_(capacity) {
71 cache_mapper_ = utils::make_unique<
72 std::unordered_map<key_t, timed_entry_t>>();
73 }
74
75 ~lru_primitive_cache_t() override;
76
77 status_t set_capacity(int capacity) override;
78 int get_capacity() const override;
79
80 value_t get_or_add(const key_t &key, const value_t &value) override;
81 void remove_if_invalidated(const key_t &key) override;
82 void update_entry(const key_t &key, const primitive_desc_t *pd) override;
83
84 int get_size() const override;
85
86 std::shared_ptr<primitive_desc_t> get_pd(const key_t &key) override;
87
88private:
89 void evict(size_t n);
90 void add(const key_t &key, const value_t &value);
91 value_t get(const key_t &key);
92
93 size_t capacity_;
94 struct timed_entry_t {
95 value_t value_;
96 std::atomic<size_t> timestamp_;
97 timed_entry_t(const value_t &value, size_t timestamp)
98 : value_(value), timestamp_(timestamp) {}
99 };
100
101 std::unordered_map<key_t, timed_entry_t> &cache_mapper() {
102 return *cache_mapper_;
103 }
104
105 const std::unordered_map<key_t, timed_entry_t> &cache_mapper() const {
106 return *cache_mapper_;
107 }
108
109 // Each entry in the cache has a corresponding key and timestamp.
110 // NOTE: pairs that contain atomics cannot be stored in an unordered_map *as
111 // an element*, since it invokes the copy constructor of std::atomic, which
112 // is deleted.
113 std::unique_ptr<std::unordered_map<key_t, timed_entry_t>> cache_mapper_;
114
115 // Used for testing.
116 friend size_t DNNL_API set_primitive_cache_capacity_without_clearing(
117 size_t capacity);
118};
119
120primitive_cache_t &primitive_cache();
121
122// Undocumented API for testing.
123status_t DNNL_API get_primitive_cache_size(int *size);
124bool DNNL_API is_primitive_in_cache(const primitive_iface_t *p_iface);
125bool DNNL_API is_pd_in_cache(const primitive_desc_iface_t *pd_iface);
126size_t DNNL_API set_primitive_cache_capacity_without_clearing(size_t capacity);
127
128} // namespace impl
129} // namespace dnnl
130#endif
131
132// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
133