1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_TSL_PLATFORM_MUTEX_H_ |
17 | #define TENSORFLOW_TSL_PLATFORM_MUTEX_H_ |
18 | |
19 | #include <chrono> // NOLINT |
20 | // for std::try_to_lock_t and std::cv_status |
21 | #include <condition_variable> // NOLINT |
22 | #include <mutex> // NOLINT |
23 | |
24 | #include "tensorflow/tsl/platform/platform.h" |
25 | #include "tensorflow/tsl/platform/thread_annotations.h" |
26 | #include "tensorflow/tsl/platform/types.h" |
27 | |
28 | // Include appropriate platform-dependent implementation details of mutex etc. |
29 | #if defined(PLATFORM_GOOGLE) |
30 | #include "tensorflow/tsl/platform/google/mutex_data.h" |
31 | #elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ |
32 | defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \ |
33 | defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS) |
34 | #include "tensorflow/tsl/platform/default/mutex_data.h" |
35 | #else |
36 | #error Define the appropriate PLATFORM_<foo> macro for this platform |
37 | #endif |
38 | |
39 | namespace tsl { |
40 | |
// Outcome of a timed wait on a condition_variable (see WaitForMilliseconds).
enum ConditionResult {
  // The wait reported that its timeout expired.
  kCond_Timeout = 0,
  // The wait did not report a timeout; a notification may have occurred.
  kCond_MaybeNotified = 1,
};

// Tag selecting mutex's constexpr constructor for objects relying on
// linker (zero) initialization.
enum LinkerInitialized {
  LINKER_INITIALIZED = 0,
};

class Condition;
class condition_variable;
46 | |
47 | // Mimic std::mutex + C++17's shared_mutex, adding a LinkerInitialized |
48 | // constructor interface. This type is as fast as mutex, but is also a shared |
49 | // lock, and provides conditional critical sections (via Await()), as an |
50 | // alternative to condition variables. |
51 | class TF_LOCKABLE mutex { |
52 | public: |
53 | mutex(); |
54 | // The default implementation of the underlying mutex is safe to use after |
55 | // the linker initialization to zero. |
56 | explicit constexpr mutex(LinkerInitialized x) |
57 | : |
58 | #if defined(PLATFORM_GOOGLE) |
59 | mu_(absl::kConstInit) |
60 | #else |
61 | mu_() |
62 | #endif |
63 | { |
64 | } |
65 | |
66 | void lock() TF_EXCLUSIVE_LOCK_FUNCTION(); |
67 | bool try_lock() TF_EXCLUSIVE_TRYLOCK_FUNCTION(true); |
68 | void unlock() TF_UNLOCK_FUNCTION(); |
69 | |
70 | void lock_shared() TF_SHARED_LOCK_FUNCTION(); |
71 | bool try_lock_shared() TF_SHARED_TRYLOCK_FUNCTION(true); |
72 | void unlock_shared() TF_UNLOCK_FUNCTION(); |
73 | |
74 | // ------- |
75 | // Conditional critical sections. |
76 | // These represent an alternative to condition variables that is easier to |
77 | // use. The predicate must be encapsulated in a function (via Condition), |
78 | // but there is no need to use a while-loop, and no need to signal the |
79 | // condition. Example: suppose "mu" protects "counter"; we wish one thread |
80 | // to wait until counter is decremented to zero by another thread. |
81 | // // Predicate expressed as a function: |
82 | // static bool IntIsZero(int* pi) { return *pi == 0; } |
83 | // |
84 | // // Waiter: |
85 | // mu.lock(); |
86 | // mu.Await(Condition(&IntIsZero, &counter)); // no loop needed |
87 | // // lock is held and counter==0... |
88 | // mu.unlock(); |
89 | // |
90 | // // Decrementer: |
91 | // mu.lock(); |
92 | // counter--; |
93 | // mu.unlock(); // no need to signal; mutex will check condition |
94 | // |
95 | // A mutex may be used with condition variables and conditional critical |
96 | // sections at the same time. Conditional critical sections are easier to |
97 | // use, but if there are multiple conditions that are simultaneously false, |
98 | // condition variables may be faster. |
99 | |
100 | // Unlock *this and wait until cond.Eval() is true, then atomically reacquire |
101 | // *this in the same mode in which it was previously held and return. |
102 | void Await(const Condition& cond); |
103 | |
104 | // Unlock *this and wait until either cond.Eval is true, or abs_deadline_ns |
105 | // has been reached, then atomically reacquire *this in the same mode in |
106 | // which it was previously held, and return whether cond.Eval() is true. |
107 | // See tsl/tsl/platform/env_time.h for the time interface. |
108 | bool AwaitWithDeadline(const Condition& cond, uint64 abs_deadline_ns); |
109 | // ------- |
110 | |
111 | private: |
112 | friend class condition_variable; |
113 | internal::MuData mu_; |
114 | }; |
115 | |
// A Condition represents a predicate on state protected by a mutex. The
// function must have no side-effects on that state. When passed to
// mutex::Await(), the function will be called with the mutex held. It may be
// called:
//   - any number of times;
//   - by any thread using the mutex; and/or
//   - with the mutex held in any mode (read or write).
// If you must use a lambda, prefix the lambda with +, and capture no variables.
// For example: Condition(+[](int *pi)->bool { return *pi == 0; }, &i)
class Condition {
 public:
  template <typename T>
  Condition(bool (*func)(T* arg), T* arg);  // Value is (*func)(arg)
  template <typename T>
  Condition(T* obj, bool (T::*method)());  // Value is (obj->*method)()
  template <typename T>
  Condition(T* obj, bool (T::*method)() const);  // Value is (obj->*method)()
  explicit Condition(const bool* flag);           // Value is *flag

  // A Condition is always bound to a predicate; there is no "empty" state.
  Condition() = delete;

  // Return the value of the predicate represented by this Condition.
  bool Eval() const { return (*this->eval_)(this); }

 private:
  // Type-erased trampoline that restores the predicate's real types and
  // evaluates it; set by whichever constructor was used.
  bool (*eval_)(const Condition*);
  bool (*function_)(void*);      // predicate of form (*function_)(arg_)
  bool (Condition::*method_)();  // predicate of form (arg_->*method_)()
  void* arg_;                    // function argument, object, or flag address

  // Trampolines that eval_ may point to.
  template <typename T>
  static bool CallFunction(const Condition* cond);  // call function_
  template <typename T>
  static bool CallMethod(const Condition* cond);  // call method_
  static bool ReturnBool(const Condition* cond);  // access *(bool *)arg_
};
151 | |
152 | // Mimic a subset of the std::unique_lock<tsl::mutex> functionality. |
153 | class TF_SCOPED_LOCKABLE mutex_lock { |
154 | public: |
155 | typedef ::tsl::mutex mutex_type; |
156 | |
157 | explicit mutex_lock(mutex_type& mu) TF_EXCLUSIVE_LOCK_FUNCTION(mu) |
158 | : mu_(&mu) { |
159 | mu_->lock(); |
160 | } |
161 | |
162 | mutex_lock(mutex_type& mu, std::try_to_lock_t) TF_EXCLUSIVE_LOCK_FUNCTION(mu) |
163 | : mu_(&mu) { |
164 | if (!mu.try_lock()) { |
165 | mu_ = nullptr; |
166 | } |
167 | } |
168 | |
169 | // Manually nulls out the source to prevent double-free. |
170 | // (std::move does not null the source pointer by default.) |
171 | mutex_lock(mutex_lock&& ml) noexcept TF_EXCLUSIVE_LOCK_FUNCTION(ml.mu_) |
172 | : mu_(ml.mu_) { |
173 | ml.mu_ = nullptr; |
174 | } |
175 | ~mutex_lock() TF_UNLOCK_FUNCTION() { |
176 | if (mu_ != nullptr) { |
177 | mu_->unlock(); |
178 | } |
179 | } |
180 | mutex_type* mutex() { return mu_; } |
181 | |
182 | explicit operator bool() const { return mu_ != nullptr; } |
183 | |
184 | private: |
185 | mutex_type* mu_; |
186 | }; |
187 | |
188 | // Catch bug where variable name is omitted, e.g. mutex_lock (mu); |
189 | #define mutex_lock(x) static_assert(0, "mutex_lock_decl_missing_var_name"); |
190 | |
191 | // Mimic a subset of the std::shared_lock<tsl::mutex> functionality. |
192 | // Name chosen to minimize conflicts with the tf_shared_lock macro, below. |
193 | class TF_SCOPED_LOCKABLE tf_shared_lock { |
194 | public: |
195 | typedef ::tsl::mutex mutex_type; |
196 | |
197 | explicit tf_shared_lock(mutex_type& mu) TF_SHARED_LOCK_FUNCTION(mu) |
198 | : mu_(&mu) { |
199 | mu_->lock_shared(); |
200 | } |
201 | |
202 | tf_shared_lock(mutex_type& mu, std::try_to_lock_t) TF_SHARED_LOCK_FUNCTION(mu) |
203 | : mu_(&mu) { |
204 | if (!mu.try_lock_shared()) { |
205 | mu_ = nullptr; |
206 | } |
207 | } |
208 | |
209 | // Manually nulls out the source to prevent double-free. |
210 | // (std::move does not null the source pointer by default.) |
211 | tf_shared_lock(tf_shared_lock&& ml) noexcept TF_SHARED_LOCK_FUNCTION(ml.mu_) |
212 | : mu_(ml.mu_) { |
213 | ml.mu_ = nullptr; |
214 | } |
215 | ~tf_shared_lock() TF_UNLOCK_FUNCTION() { |
216 | if (mu_ != nullptr) { |
217 | mu_->unlock_shared(); |
218 | } |
219 | } |
220 | mutex_type* mutex() { return mu_; } |
221 | |
222 | explicit operator bool() const { return mu_ != nullptr; } |
223 | |
224 | private: |
225 | mutex_type* mu_; |
226 | }; |
227 | |
228 | // Catch bug where variable name is omitted, e.g. tf_shared_lock (mu); |
229 | #define tf_shared_lock(x) \ |
230 | static_assert(0, "tf_shared_lock_decl_missing_var_name"); |
231 | |
232 | // Mimic std::condition_variable. |
233 | class condition_variable { |
234 | public: |
235 | condition_variable(); |
236 | |
237 | void wait(mutex_lock& lock); |
238 | template <class Rep, class Period> |
239 | std::cv_status wait_for(mutex_lock& lock, |
240 | std::chrono::duration<Rep, Period> dur); |
241 | void notify_one(); |
242 | void notify_all(); |
243 | |
244 | private: |
245 | friend ConditionResult WaitForMilliseconds(mutex_lock* mu, |
246 | condition_variable* cv, |
247 | int64_t ms); |
248 | internal::CVData cv_; |
249 | }; |
250 | |
251 | // Like "cv->wait(*mu)", except that it only waits for up to "ms" milliseconds. |
252 | // |
253 | // Returns kCond_Timeout if the timeout expired without this |
254 | // thread noticing a signal on the condition variable. Otherwise may |
255 | // return either kCond_Timeout or kCond_MaybeNotified |
256 | inline ConditionResult WaitForMilliseconds(mutex_lock* mu, |
257 | condition_variable* cv, int64_t ms) { |
258 | std::cv_status s = cv->wait_for(*mu, std::chrono::milliseconds(ms)); |
259 | return (s == std::cv_status::timeout) ? kCond_Timeout : kCond_MaybeNotified; |
260 | } |
261 | |
262 | // ------------------------------------------------------------ |
263 | // Implementation details follow. Clients should ignore them. |
264 | |
265 | // private static |
266 | template <typename T> |
267 | inline bool Condition::CallFunction(const Condition* cond) { |
268 | bool (*fn)(T*) = reinterpret_cast<bool (*)(T*)>(cond->function_); |
269 | return (*fn)(static_cast<T*>(cond->arg_)); |
270 | } |
271 | |
272 | template <typename T> |
273 | inline Condition::Condition(bool (*func)(T*), T* arg) |
274 | : eval_(&CallFunction<T>), |
275 | function_(reinterpret_cast<bool (*)(void*)>(func)), |
276 | method_(nullptr), |
277 | arg_(const_cast<void*>(static_cast<const void*>(arg))) {} |
278 | |
279 | // private static |
280 | template <typename T> |
281 | inline bool Condition::CallMethod(const Condition* cond) { |
282 | bool (T::*m)() = reinterpret_cast<bool (T::*)()>(cond->method_); |
283 | return (static_cast<T*>(cond->arg_)->*m)(); |
284 | } |
285 | |
286 | template <typename T> |
287 | inline Condition::Condition(T* obj, bool (T::*method)()) |
288 | : eval_(&CallMethod<T>), |
289 | function_(nullptr), |
290 | method_(reinterpret_cast<bool (Condition::*)()>(method)), |
291 | arg_(const_cast<void*>(static_cast<const void*>(obj))) {} |
292 | |
293 | template <typename T> |
294 | inline Condition::Condition(T* obj, bool (T::*method)() const) |
295 | : eval_(&CallMethod<T>), |
296 | function_(nullptr), |
297 | method_(reinterpret_cast<bool (Condition::*)()>(method)), |
298 | arg_(const_cast<void*>(static_cast<const void*>(obj))) {} |
299 | |
300 | // private static |
301 | inline bool Condition::ReturnBool(const Condition* cond) { |
302 | return *static_cast<bool*>(cond->arg_); |
303 | } |
304 | |
305 | inline Condition::Condition(const bool* flag) |
306 | : eval_(&ReturnBool), |
307 | function_(nullptr), |
308 | method_(nullptr), |
309 | arg_(const_cast<void*>(static_cast<const void*>(flag))) {} |
310 | |
311 | } // namespace tsl |
312 | |
313 | // Include appropriate platform-dependent implementation details of mutex etc. |
314 | #if defined(PLATFORM_GOOGLE) |
315 | #include "tensorflow/tsl/platform/google/mutex.h" |
316 | #elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ |
317 | defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \ |
318 | defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS) |
319 | #include "tensorflow/tsl/platform/default/mutex.h" |
320 | #else |
321 | #error Define the appropriate PLATFORM_<foo> macro for this platform |
322 | #endif |
323 | |
324 | #endif // TENSORFLOW_TSL_PLATFORM_MUTEX_H_ |
325 | |