1 | // Copyright 2019 The Marl Authors. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #ifndef marl_condition_variable_h |
16 | #define marl_condition_variable_h |
17 | |
#include "containers.h"
#include "debug.h"
#include "memory.h"
#include "mutex.h"
#include "scheduler.h"
#include "tsa.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
27 | |
28 | namespace marl { |
29 | |
30 | // ConditionVariable is a synchronization primitive that can be used to block |
31 | // one or more fibers or threads, until another fiber or thread modifies a |
32 | // shared variable (the condition) and notifies the ConditionVariable. |
33 | // |
34 | // If the ConditionVariable is blocked on a thread with a Scheduler bound, the |
35 | // thread will work on other tasks until the ConditionVariable is unblocked. |
36 | class ConditionVariable { |
37 | public: |
38 | MARL_NO_EXPORT inline ConditionVariable( |
39 | Allocator* allocator = Allocator::Default); |
40 | |
41 | // notify_one() notifies and potentially unblocks one waiting fiber or thread. |
42 | MARL_NO_EXPORT inline void notify_one(); |
43 | |
44 | // notify_all() notifies and potentially unblocks all waiting fibers and/or |
45 | // threads. |
46 | MARL_NO_EXPORT inline void notify_all(); |
47 | |
48 | // wait() blocks the current fiber or thread until the predicate is satisfied |
49 | // and the ConditionVariable is notified. |
50 | template <typename Predicate> |
51 | MARL_NO_EXPORT inline void wait(marl::lock& lock, Predicate&& pred); |
52 | |
53 | // wait_for() blocks the current fiber or thread until the predicate is |
54 | // satisfied, and the ConditionVariable is notified, or the timeout has been |
55 | // reached. Returns false if pred still evaluates to false after the timeout |
56 | // has been reached, otherwise true. |
57 | template <typename Rep, typename Period, typename Predicate> |
58 | MARL_NO_EXPORT inline bool wait_for( |
59 | marl::lock& lock, |
60 | const std::chrono::duration<Rep, Period>& duration, |
61 | Predicate&& pred); |
62 | |
63 | // wait_until() blocks the current fiber or thread until the predicate is |
64 | // satisfied, and the ConditionVariable is notified, or the timeout has been |
65 | // reached. Returns false if pred still evaluates to false after the timeout |
66 | // has been reached, otherwise true. |
67 | template <typename Clock, typename Duration, typename Predicate> |
68 | MARL_NO_EXPORT inline bool wait_until( |
69 | marl::lock& lock, |
70 | const std::chrono::time_point<Clock, Duration>& timeout, |
71 | Predicate&& pred); |
72 | |
73 | private: |
74 | ConditionVariable(const ConditionVariable&) = delete; |
75 | ConditionVariable(ConditionVariable&&) = delete; |
76 | ConditionVariable& operator=(const ConditionVariable&) = delete; |
77 | ConditionVariable& operator=(ConditionVariable&&) = delete; |
78 | |
79 | marl::mutex mutex; |
80 | containers::list<Scheduler::Fiber*> waiting; |
81 | std::condition_variable condition; |
82 | std::atomic<int> numWaiting = {0}; |
83 | std::atomic<int> numWaitingOnCondition = {0}; |
84 | }; |
85 | |
86 | ConditionVariable::ConditionVariable( |
87 | Allocator* allocator /* = Allocator::Default */) |
88 | : waiting(allocator) {} |
89 | |
90 | void ConditionVariable::notify_one() { |
91 | if (numWaiting == 0) { |
92 | return; |
93 | } |
94 | { |
95 | marl::lock lock(mutex); |
96 | if (waiting.size() > 0) { |
97 | (*waiting.begin())->notify(); // Only wake one fiber. |
98 | return; |
99 | } |
100 | } |
101 | if (numWaitingOnCondition > 0) { |
102 | condition.notify_one(); |
103 | } |
104 | } |
105 | |
106 | void ConditionVariable::notify_all() { |
107 | if (numWaiting == 0) { |
108 | return; |
109 | } |
110 | { |
111 | marl::lock lock(mutex); |
112 | for (auto fiber : waiting) { |
113 | fiber->notify(); |
114 | } |
115 | } |
116 | if (numWaitingOnCondition > 0) { |
117 | condition.notify_all(); |
118 | } |
119 | } |
120 | |
121 | template <typename Predicate> |
122 | void ConditionVariable::wait(marl::lock& lock, Predicate&& pred) { |
123 | if (pred()) { |
124 | return; |
125 | } |
126 | numWaiting++; |
127 | if (auto fiber = Scheduler::Fiber::current()) { |
128 | // Currently executing on a scheduler fiber. |
129 | // Yield to let other tasks run that can unblock this fiber. |
130 | mutex.lock(); |
131 | auto it = waiting.emplace_front(fiber); |
132 | mutex.unlock(); |
133 | |
134 | fiber->wait(lock, pred); |
135 | |
136 | mutex.lock(); |
137 | waiting.erase(it); |
138 | mutex.unlock(); |
139 | } else { |
140 | // Currently running outside of the scheduler. |
141 | // Delegate to the std::condition_variable. |
142 | numWaitingOnCondition++; |
143 | lock.wait(condition, pred); |
144 | numWaitingOnCondition--; |
145 | } |
146 | numWaiting--; |
147 | } |
148 | |
149 | template <typename Rep, typename Period, typename Predicate> |
150 | bool ConditionVariable::wait_for( |
151 | marl::lock& lock, |
152 | const std::chrono::duration<Rep, Period>& duration, |
153 | Predicate&& pred) { |
154 | return wait_until(lock, std::chrono::system_clock::now() + duration, pred); |
155 | } |
156 | |
157 | template <typename Clock, typename Duration, typename Predicate> |
158 | bool ConditionVariable::wait_until( |
159 | marl::lock& lock, |
160 | const std::chrono::time_point<Clock, Duration>& timeout, |
161 | Predicate&& pred) { |
162 | if (pred()) { |
163 | return true; |
164 | } |
165 | |
166 | if (auto fiber = Scheduler::Fiber::current()) { |
167 | numWaiting++; |
168 | |
169 | // Currently executing on a scheduler fiber. |
170 | // Yield to let other tasks run that can unblock this fiber. |
171 | mutex.lock(); |
172 | auto it = waiting.emplace_front(fiber); |
173 | mutex.unlock(); |
174 | |
175 | auto res = fiber->wait(lock, timeout, pred); |
176 | |
177 | mutex.lock(); |
178 | waiting.erase(it); |
179 | mutex.unlock(); |
180 | |
181 | numWaiting--; |
182 | return res; |
183 | } |
184 | |
185 | // Currently running outside of the scheduler. |
186 | // Delegate to the std::condition_variable. |
187 | numWaiting++; |
188 | numWaitingOnCondition++; |
189 | auto res = lock.wait_until(condition, timeout, pred); |
190 | numWaitingOnCondition--; |
191 | numWaiting--; |
192 | return res; |
193 | } |
194 | |
195 | } // namespace marl |
196 | |
197 | #endif // marl_condition_variable_h |
198 | |