1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_THREAD_LOCAL_STORAGE_HPP |
18 | #define COMMON_THREAD_LOCAL_STORAGE_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <thread> |
22 | #include <utility> |
23 | #include <unordered_map> |
24 | |
25 | #include "rw_mutex.hpp" |
26 | #include "z_magic.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace utils { |
31 | |
32 | template <typename T> |
33 | struct thread_local_storage_t { |
34 | thread_local_storage_t() = default; |
35 | |
36 | DNNL_DISALLOW_COPY_AND_ASSIGN(thread_local_storage_t); |
37 | |
38 | template <typename U> |
39 | T &set(U &&value) { |
40 | utils::lock_write_t lock_w(mutex_); |
41 | const auto tid = std::this_thread::get_id(); |
42 | assert(storage_.count(tid) == 0); |
43 | auto it = storage_.emplace(tid, std::forward<U>(value)); |
44 | assert(it.second); |
45 | return it.first->second; |
46 | } |
47 | |
48 | bool is_set() { |
49 | utils::lock_read_t lock_r(mutex_); |
50 | return storage_.find(std::this_thread::get_id()) != storage_.end(); |
51 | } |
52 | |
53 | T &get() { |
54 | utils::lock_read_t lock_r(mutex_); |
55 | return storage_.at(std::this_thread::get_id()); |
56 | } |
57 | |
58 | template <typename U> |
59 | T &get(U &&def_value) { |
60 | { |
61 | utils::lock_read_t lock_r(mutex_); |
62 | auto it = storage_.find(std::this_thread::get_id()); |
63 | if (it != storage_.end()) return it->second; |
64 | } |
65 | return set(std::forward<U>(def_value)); |
66 | } |
67 | |
68 | private: |
69 | std::unordered_map<std::thread::id, T> storage_; |
70 | utils::rw_mutex_t mutex_; |
71 | }; |
72 | |
73 | } // namespace utils |
74 | } // namespace impl |
75 | } // namespace dnnl |
76 | |
77 | #endif |
78 | |