1 | /******************************************************************************* |
2 | * Copyright 2019-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_PRIMITIVE_CACHE_HPP |
18 | #define COMMON_PRIMITIVE_CACHE_HPP |
19 | |
#include <atomic>
#include <future>
#include <memory>
#include <thread>
#include <unordered_map>

#include "c_types_map.hpp"
#include "oneapi/dnnl/dnnl.h"
#include "primitive_hashing.hpp"
#include "rw_mutex.hpp"
#include "type_helpers.hpp"
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | |
34 | struct primitive_t; |
35 | struct primitive_cache_t : public c_compatible { |
36 | struct cache_value_t { |
37 | std::shared_ptr<primitive_t> primitive; |
38 | status_t status; |
39 | }; |
40 | using key_t = primitive_hashing::key_t; |
41 | using value_t = std::shared_future<cache_value_t>; |
42 | |
43 | virtual ~primitive_cache_t() = default; |
44 | |
45 | virtual status_t set_capacity(int capacity) = 0; |
46 | virtual int get_capacity() const = 0; |
47 | |
48 | virtual value_t get_or_add(const key_t &key, const value_t &value) = 0; |
49 | virtual void remove_if_invalidated(const key_t &key) = 0; |
50 | virtual void update_entry(const key_t &key, const primitive_desc_t *pd) = 0; |
51 | |
52 | virtual int get_size() const = 0; |
53 | |
54 | virtual std::shared_ptr<primitive_desc_t> get_pd(const key_t &key) = 0; |
55 | |
56 | protected: |
57 | static utils::rw_mutex_t &rw_mutex() { |
58 | static utils::rw_mutex_t mutex; |
59 | return mutex; |
60 | } |
61 | |
62 | void lock_read() { rw_mutex().lock_read(); } |
63 | void lock_write() { rw_mutex().lock_write(); } |
64 | void unlock_read() { rw_mutex().unlock_read(); } |
65 | void unlock_write() { rw_mutex().unlock_write(); } |
66 | }; |
67 | |
68 | // The cache uses LRU replacement policy |
69 | struct lru_primitive_cache_t : public primitive_cache_t { |
70 | lru_primitive_cache_t(int capacity) : capacity_(capacity) { |
71 | cache_mapper_ = utils::make_unique< |
72 | std::unordered_map<key_t, timed_entry_t>>(); |
73 | } |
74 | |
75 | ~lru_primitive_cache_t() override; |
76 | |
77 | status_t set_capacity(int capacity) override; |
78 | int get_capacity() const override; |
79 | |
80 | value_t get_or_add(const key_t &key, const value_t &value) override; |
81 | void remove_if_invalidated(const key_t &key) override; |
82 | void update_entry(const key_t &key, const primitive_desc_t *pd) override; |
83 | |
84 | int get_size() const override; |
85 | |
86 | std::shared_ptr<primitive_desc_t> get_pd(const key_t &key) override; |
87 | |
88 | private: |
89 | void evict(size_t n); |
90 | void add(const key_t &key, const value_t &value); |
91 | value_t get(const key_t &key); |
92 | |
93 | size_t capacity_; |
94 | struct timed_entry_t { |
95 | value_t value_; |
96 | std::atomic<size_t> timestamp_; |
97 | timed_entry_t(const value_t &value, size_t timestamp) |
98 | : value_(value), timestamp_(timestamp) {} |
99 | }; |
100 | |
101 | std::unordered_map<key_t, timed_entry_t> &cache_mapper() { |
102 | return *cache_mapper_; |
103 | } |
104 | |
105 | const std::unordered_map<key_t, timed_entry_t> &cache_mapper() const { |
106 | return *cache_mapper_; |
107 | } |
108 | |
109 | // Each entry in the cache has a corresponding key and timestamp. |
110 | // NOTE: pairs that contain atomics cannot be stored in an unordered_map *as |
111 | // an element*, since it invokes the copy constructor of std::atomic, which |
112 | // is deleted. |
113 | std::unique_ptr<std::unordered_map<key_t, timed_entry_t>> cache_mapper_; |
114 | |
115 | // Used for testing. |
116 | friend size_t DNNL_API set_primitive_cache_capacity_without_clearing( |
117 | size_t capacity); |
118 | }; |
119 | |
// Returns a reference to a primitive cache object through the abstract
// primitive_cache_t interface. The definition lives outside this header.
// NOTE(review): presumably this is the single library-wide cache instance --
// confirm against the implementation.
primitive_cache_t &primitive_cache();
121 | |
// Undocumented API for testing.
// All of the functions below are marked DNNL_API and are defined outside
// this header.

// NOTE(review): presumably reports the current number of cached entries
// through `size` and returns a status code -- confirm.
status_t DNNL_API get_primitive_cache_size(int *size);
// NOTE(review): presumably tells whether the primitive behind `p_iface` is
// currently stored in the cache -- confirm.
bool DNNL_API is_primitive_in_cache(const primitive_iface_t *p_iface);
// NOTE(review): presumably tells whether the primitive descriptor behind
// `pd_iface` is currently stored in the cache -- confirm.
bool DNNL_API is_pd_in_cache(const primitive_desc_iface_t *pd_iface);
// Declared as a friend of lru_primitive_cache_t, so it may access that
// class's private members.
// NOTE(review): the meaning of the returned size_t is not visible here;
// presumably the previous capacity -- confirm.
size_t DNNL_API set_primitive_cache_capacity_without_clearing(size_t capacity);
127 | |
128 | } // namespace impl |
129 | } // namespace dnnl |
130 | #endif |
131 | |
132 | // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |
133 | |