1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_THREAD_LOCAL_STORAGE_HPP
18#define COMMON_THREAD_LOCAL_STORAGE_HPP
19
20#include <assert.h>
21#include <thread>
22#include <utility>
23#include <unordered_map>
24
25#include "rw_mutex.hpp"
26#include "z_magic.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace utils {
31
32template <typename T>
33struct thread_local_storage_t {
34 thread_local_storage_t() = default;
35
36 DNNL_DISALLOW_COPY_AND_ASSIGN(thread_local_storage_t);
37
38 template <typename U>
39 T &set(U &&value) {
40 utils::lock_write_t lock_w(mutex_);
41 const auto tid = std::this_thread::get_id();
42 assert(storage_.count(tid) == 0);
43 auto it = storage_.emplace(tid, std::forward<U>(value));
44 assert(it.second);
45 return it.first->second;
46 }
47
48 bool is_set() {
49 utils::lock_read_t lock_r(mutex_);
50 return storage_.find(std::this_thread::get_id()) != storage_.end();
51 }
52
53 T &get() {
54 utils::lock_read_t lock_r(mutex_);
55 return storage_.at(std::this_thread::get_id());
56 }
57
58 template <typename U>
59 T &get(U &&def_value) {
60 {
61 utils::lock_read_t lock_r(mutex_);
62 auto it = storage_.find(std::this_thread::get_id());
63 if (it != storage_.end()) return it->second;
64 }
65 return set(std::forward<U>(def_value));
66 }
67
68private:
69 std::unordered_map<std::thread::id, T> storage_;
70 utils::rw_mutex_t mutex_;
71};
72
73} // namespace utils
74} // namespace impl
75} // namespace dnnl
76
77#endif
78