1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_TSL_PLATFORM_MUTEX_H_
17#define TENSORFLOW_TSL_PLATFORM_MUTEX_H_
18
19#include <chrono> // NOLINT
20// for std::try_to_lock_t and std::cv_status
21#include <condition_variable> // NOLINT
22#include <mutex> // NOLINT
23
24#include "tensorflow/tsl/platform/platform.h"
25#include "tensorflow/tsl/platform/thread_annotations.h"
26#include "tensorflow/tsl/platform/types.h"
27
28// Include appropriate platform-dependent implementation details of mutex etc.
29#if defined(PLATFORM_GOOGLE)
30#include "tensorflow/tsl/platform/google/mutex_data.h"
31#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
32 defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
33 defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
34#include "tensorflow/tsl/platform/default/mutex_data.h"
35#else
36#error Define the appropriate PLATFORM_<foo> macro for this platform
37#endif
38
39namespace tsl {
40
// Outcome of a timed condition-variable wait, as returned by
// WaitForMilliseconds().
enum ConditionResult {
  kCond_Timeout,        // the wait reported std::cv_status::timeout
  kCond_MaybeNotified,  // any other outcome; a notification may have occurred
};

// Tag type: its single value selects mutex's constexpr constructor.
enum LinkerInitialized { LINKER_INITIALIZED };

// Forward declarations of classes defined later in this header.
class condition_variable;
class Condition;
46
47// Mimic std::mutex + C++17's shared_mutex, adding a LinkerInitialized
48// constructor interface. This type is as fast as mutex, but is also a shared
49// lock, and provides conditional critical sections (via Await()), as an
50// alternative to condition variables.
51class TF_LOCKABLE mutex {
52 public:
53 mutex();
54 // The default implementation of the underlying mutex is safe to use after
55 // the linker initialization to zero.
56 explicit constexpr mutex(LinkerInitialized x)
57 :
58#if defined(PLATFORM_GOOGLE)
59 mu_(absl::kConstInit)
60#else
61 mu_()
62#endif
63 {
64 }
65
66 void lock() TF_EXCLUSIVE_LOCK_FUNCTION();
67 bool try_lock() TF_EXCLUSIVE_TRYLOCK_FUNCTION(true);
68 void unlock() TF_UNLOCK_FUNCTION();
69
70 void lock_shared() TF_SHARED_LOCK_FUNCTION();
71 bool try_lock_shared() TF_SHARED_TRYLOCK_FUNCTION(true);
72 void unlock_shared() TF_UNLOCK_FUNCTION();
73
74 // -------
75 // Conditional critical sections.
76 // These represent an alternative to condition variables that is easier to
77 // use. The predicate must be encapsulated in a function (via Condition),
78 // but there is no need to use a while-loop, and no need to signal the
79 // condition. Example: suppose "mu" protects "counter"; we wish one thread
80 // to wait until counter is decremented to zero by another thread.
81 // // Predicate expressed as a function:
82 // static bool IntIsZero(int* pi) { return *pi == 0; }
83 //
84 // // Waiter:
85 // mu.lock();
86 // mu.Await(Condition(&IntIsZero, &counter)); // no loop needed
87 // // lock is held and counter==0...
88 // mu.unlock();
89 //
90 // // Decrementer:
91 // mu.lock();
92 // counter--;
93 // mu.unlock(); // no need to signal; mutex will check condition
94 //
95 // A mutex may be used with condition variables and conditional critical
96 // sections at the same time. Conditional critical sections are easier to
97 // use, but if there are multiple conditions that are simultaneously false,
98 // condition variables may be faster.
99
100 // Unlock *this and wait until cond.Eval() is true, then atomically reacquire
101 // *this in the same mode in which it was previously held and return.
102 void Await(const Condition& cond);
103
104 // Unlock *this and wait until either cond.Eval is true, or abs_deadline_ns
105 // has been reached, then atomically reacquire *this in the same mode in
106 // which it was previously held, and return whether cond.Eval() is true.
107 // See tsl/tsl/platform/env_time.h for the time interface.
108 bool AwaitWithDeadline(const Condition& cond, uint64 abs_deadline_ns);
109 // -------
110
111 private:
112 friend class condition_variable;
113 internal::MuData mu_;
114};
115
// A Condition represents a predicate on state protected by a mutex. The
// function must have no side-effects on that state. When passed to
// mutex::Await(), the function will be called with the mutex held. It may be
// called:
//   - any number of times;
//   - by any thread using the mutex; and/or
//   - with the mutex held in any mode (read or write).
// If you must use a lambda, prefix the lambda with +, and capture no variables.
// For example: Condition(+[](int *pi)->bool { return *pi == 0; }, &i)
//
// A Condition stores only a pointer to the argument, object, or flag it is
// given; it does not own or copy the pointee, which must therefore outlive
// every use of the Condition.
class Condition {
 public:
  template <typename T>
  Condition(bool (*func)(T* arg), T* arg);  // Value is (*func)(arg)
  template <typename T>
  Condition(T* obj, bool (T::*method)());  // Value is (obj->*method)()
  template <typename T>
  Condition(T* obj, bool (T::*method)() const);  // Value is (obj->*method)()
  explicit Condition(const bool* flag);  // Value is *flag

  // Return the value of the predicate represented by this Condition.
  bool Eval() const { return (*this->eval_)(this); }

 private:
  // Evaluator chosen by the constructor; it restores the erased types below
  // and computes the predicate (e.g. CallFunction, CallMethod, or ReturnBool).
  bool (*eval_)(const Condition*);
  bool (*function_)(void*);      // type-erased predicate (*function_)(arg_)
  bool (Condition::*method_)();  // type-erased predicate (arg_->*method_)()
  void* arg_;                    // function argument, object, or flag
  // Declared private so that clients cannot default-construct a Condition.
  Condition();
  // The following functions can be pointed to by the eval_ field.
  template <typename T>
  static bool CallFunction(const Condition* cond);  // calls function_
  template <typename T>
  static bool CallMethod(const Condition* cond);  // calls method_
  static bool ReturnBool(const Condition* cond);  // reads *(bool*)arg_
};
151
152// Mimic a subset of the std::unique_lock<tsl::mutex> functionality.
153class TF_SCOPED_LOCKABLE mutex_lock {
154 public:
155 typedef ::tsl::mutex mutex_type;
156
157 explicit mutex_lock(mutex_type& mu) TF_EXCLUSIVE_LOCK_FUNCTION(mu)
158 : mu_(&mu) {
159 mu_->lock();
160 }
161
162 mutex_lock(mutex_type& mu, std::try_to_lock_t) TF_EXCLUSIVE_LOCK_FUNCTION(mu)
163 : mu_(&mu) {
164 if (!mu.try_lock()) {
165 mu_ = nullptr;
166 }
167 }
168
169 // Manually nulls out the source to prevent double-free.
170 // (std::move does not null the source pointer by default.)
171 mutex_lock(mutex_lock&& ml) noexcept TF_EXCLUSIVE_LOCK_FUNCTION(ml.mu_)
172 : mu_(ml.mu_) {
173 ml.mu_ = nullptr;
174 }
175 ~mutex_lock() TF_UNLOCK_FUNCTION() {
176 if (mu_ != nullptr) {
177 mu_->unlock();
178 }
179 }
180 mutex_type* mutex() { return mu_; }
181
182 explicit operator bool() const { return mu_ != nullptr; }
183
184 private:
185 mutex_type* mu_;
186};
187
188// Catch bug where variable name is omitted, e.g. mutex_lock (mu);
189#define mutex_lock(x) static_assert(0, "mutex_lock_decl_missing_var_name");
190
191// Mimic a subset of the std::shared_lock<tsl::mutex> functionality.
192// Name chosen to minimize conflicts with the tf_shared_lock macro, below.
193class TF_SCOPED_LOCKABLE tf_shared_lock {
194 public:
195 typedef ::tsl::mutex mutex_type;
196
197 explicit tf_shared_lock(mutex_type& mu) TF_SHARED_LOCK_FUNCTION(mu)
198 : mu_(&mu) {
199 mu_->lock_shared();
200 }
201
202 tf_shared_lock(mutex_type& mu, std::try_to_lock_t) TF_SHARED_LOCK_FUNCTION(mu)
203 : mu_(&mu) {
204 if (!mu.try_lock_shared()) {
205 mu_ = nullptr;
206 }
207 }
208
209 // Manually nulls out the source to prevent double-free.
210 // (std::move does not null the source pointer by default.)
211 tf_shared_lock(tf_shared_lock&& ml) noexcept TF_SHARED_LOCK_FUNCTION(ml.mu_)
212 : mu_(ml.mu_) {
213 ml.mu_ = nullptr;
214 }
215 ~tf_shared_lock() TF_UNLOCK_FUNCTION() {
216 if (mu_ != nullptr) {
217 mu_->unlock_shared();
218 }
219 }
220 mutex_type* mutex() { return mu_; }
221
222 explicit operator bool() const { return mu_ != nullptr; }
223
224 private:
225 mutex_type* mu_;
226};
227
228// Catch bug where variable name is omitted, e.g. tf_shared_lock (mu);
229#define tf_shared_lock(x) \
230 static_assert(0, "tf_shared_lock_decl_missing_var_name");
231
232// Mimic std::condition_variable.
233class condition_variable {
234 public:
235 condition_variable();
236
237 void wait(mutex_lock& lock);
238 template <class Rep, class Period>
239 std::cv_status wait_for(mutex_lock& lock,
240 std::chrono::duration<Rep, Period> dur);
241 void notify_one();
242 void notify_all();
243
244 private:
245 friend ConditionResult WaitForMilliseconds(mutex_lock* mu,
246 condition_variable* cv,
247 int64_t ms);
248 internal::CVData cv_;
249};
250
251// Like "cv->wait(*mu)", except that it only waits for up to "ms" milliseconds.
252//
253// Returns kCond_Timeout if the timeout expired without this
254// thread noticing a signal on the condition variable. Otherwise may
255// return either kCond_Timeout or kCond_MaybeNotified
256inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
257 condition_variable* cv, int64_t ms) {
258 std::cv_status s = cv->wait_for(*mu, std::chrono::milliseconds(ms));
259 return (s == std::cv_status::timeout) ? kCond_Timeout : kCond_MaybeNotified;
260}
261
262// ------------------------------------------------------------
263// Implementation details follow. Clients should ignore them.
264
265// private static
266template <typename T>
267inline bool Condition::CallFunction(const Condition* cond) {
268 bool (*fn)(T*) = reinterpret_cast<bool (*)(T*)>(cond->function_);
269 return (*fn)(static_cast<T*>(cond->arg_));
270}
271
272template <typename T>
273inline Condition::Condition(bool (*func)(T*), T* arg)
274 : eval_(&CallFunction<T>),
275 function_(reinterpret_cast<bool (*)(void*)>(func)),
276 method_(nullptr),
277 arg_(const_cast<void*>(static_cast<const void*>(arg))) {}
278
279// private static
280template <typename T>
281inline bool Condition::CallMethod(const Condition* cond) {
282 bool (T::*m)() = reinterpret_cast<bool (T::*)()>(cond->method_);
283 return (static_cast<T*>(cond->arg_)->*m)();
284}
285
286template <typename T>
287inline Condition::Condition(T* obj, bool (T::*method)())
288 : eval_(&CallMethod<T>),
289 function_(nullptr),
290 method_(reinterpret_cast<bool (Condition::*)()>(method)),
291 arg_(const_cast<void*>(static_cast<const void*>(obj))) {}
292
293template <typename T>
294inline Condition::Condition(T* obj, bool (T::*method)() const)
295 : eval_(&CallMethod<T>),
296 function_(nullptr),
297 method_(reinterpret_cast<bool (Condition::*)()>(method)),
298 arg_(const_cast<void*>(static_cast<const void*>(obj))) {}
299
300// private static
301inline bool Condition::ReturnBool(const Condition* cond) {
302 return *static_cast<bool*>(cond->arg_);
303}
304
305inline Condition::Condition(const bool* flag)
306 : eval_(&ReturnBool),
307 function_(nullptr),
308 method_(nullptr),
309 arg_(const_cast<void*>(static_cast<const void*>(flag))) {}
310
311} // namespace tsl
312
313// Include appropriate platform-dependent implementation details of mutex etc.
314#if defined(PLATFORM_GOOGLE)
315#include "tensorflow/tsl/platform/google/mutex.h"
316#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \
317 defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_POSIX_IOS) || \
318 defined(PLATFORM_GOOGLE_IOS) || defined(PLATFORM_WINDOWS)
319#include "tensorflow/tsl/platform/default/mutex.h"
320#else
321#error Define the appropriate PLATFORM_<foo> macro for this platform
322#endif
323
324#endif // TENSORFLOW_TSL_PLATFORM_MUTEX_H_
325