1// Copyright 2019 The Marl Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef marl_condition_variable_h
16#define marl_condition_variable_h
17
18#include "containers.h"
19#include "debug.h"
20#include "memory.h"
21#include "mutex.h"
22#include "scheduler.h"
23#include "tsa.h"
24
25#include <atomic>
26#include <condition_variable>
27
28namespace marl {
29
30// ConditionVariable is a synchronization primitive that can be used to block
31// one or more fibers or threads, until another fiber or thread modifies a
32// shared variable (the condition) and notifies the ConditionVariable.
33//
34// If the ConditionVariable is blocked on a thread with a Scheduler bound, the
35// thread will work on other tasks until the ConditionVariable is unblocked.
36class ConditionVariable {
37 public:
38 MARL_NO_EXPORT inline ConditionVariable(
39 Allocator* allocator = Allocator::Default);
40
41 // notify_one() notifies and potentially unblocks one waiting fiber or thread.
42 MARL_NO_EXPORT inline void notify_one();
43
44 // notify_all() notifies and potentially unblocks all waiting fibers and/or
45 // threads.
46 MARL_NO_EXPORT inline void notify_all();
47
48 // wait() blocks the current fiber or thread until the predicate is satisfied
49 // and the ConditionVariable is notified.
50 template <typename Predicate>
51 MARL_NO_EXPORT inline void wait(marl::lock& lock, Predicate&& pred);
52
53 // wait_for() blocks the current fiber or thread until the predicate is
54 // satisfied, and the ConditionVariable is notified, or the timeout has been
55 // reached. Returns false if pred still evaluates to false after the timeout
56 // has been reached, otherwise true.
57 template <typename Rep, typename Period, typename Predicate>
58 MARL_NO_EXPORT inline bool wait_for(
59 marl::lock& lock,
60 const std::chrono::duration<Rep, Period>& duration,
61 Predicate&& pred);
62
63 // wait_until() blocks the current fiber or thread until the predicate is
64 // satisfied, and the ConditionVariable is notified, or the timeout has been
65 // reached. Returns false if pred still evaluates to false after the timeout
66 // has been reached, otherwise true.
67 template <typename Clock, typename Duration, typename Predicate>
68 MARL_NO_EXPORT inline bool wait_until(
69 marl::lock& lock,
70 const std::chrono::time_point<Clock, Duration>& timeout,
71 Predicate&& pred);
72
73 private:
74 ConditionVariable(const ConditionVariable&) = delete;
75 ConditionVariable(ConditionVariable&&) = delete;
76 ConditionVariable& operator=(const ConditionVariable&) = delete;
77 ConditionVariable& operator=(ConditionVariable&&) = delete;
78
79 marl::mutex mutex;
80 containers::list<Scheduler::Fiber*> waiting;
81 std::condition_variable condition;
82 std::atomic<int> numWaiting = {0};
83 std::atomic<int> numWaitingOnCondition = {0};
84};
85
86ConditionVariable::ConditionVariable(
87 Allocator* allocator /* = Allocator::Default */)
88 : waiting(allocator) {}
89
90void ConditionVariable::notify_one() {
91 if (numWaiting == 0) {
92 return;
93 }
94 {
95 marl::lock lock(mutex);
96 if (waiting.size() > 0) {
97 (*waiting.begin())->notify(); // Only wake one fiber.
98 return;
99 }
100 }
101 if (numWaitingOnCondition > 0) {
102 condition.notify_one();
103 }
104}
105
106void ConditionVariable::notify_all() {
107 if (numWaiting == 0) {
108 return;
109 }
110 {
111 marl::lock lock(mutex);
112 for (auto fiber : waiting) {
113 fiber->notify();
114 }
115 }
116 if (numWaitingOnCondition > 0) {
117 condition.notify_all();
118 }
119}
120
121template <typename Predicate>
122void ConditionVariable::wait(marl::lock& lock, Predicate&& pred) {
123 if (pred()) {
124 return;
125 }
126 numWaiting++;
127 if (auto fiber = Scheduler::Fiber::current()) {
128 // Currently executing on a scheduler fiber.
129 // Yield to let other tasks run that can unblock this fiber.
130 mutex.lock();
131 auto it = waiting.emplace_front(fiber);
132 mutex.unlock();
133
134 fiber->wait(lock, pred);
135
136 mutex.lock();
137 waiting.erase(it);
138 mutex.unlock();
139 } else {
140 // Currently running outside of the scheduler.
141 // Delegate to the std::condition_variable.
142 numWaitingOnCondition++;
143 lock.wait(condition, pred);
144 numWaitingOnCondition--;
145 }
146 numWaiting--;
147}
148
149template <typename Rep, typename Period, typename Predicate>
150bool ConditionVariable::wait_for(
151 marl::lock& lock,
152 const std::chrono::duration<Rep, Period>& duration,
153 Predicate&& pred) {
154 return wait_until(lock, std::chrono::system_clock::now() + duration, pred);
155}
156
157template <typename Clock, typename Duration, typename Predicate>
158bool ConditionVariable::wait_until(
159 marl::lock& lock,
160 const std::chrono::time_point<Clock, Duration>& timeout,
161 Predicate&& pred) {
162 if (pred()) {
163 return true;
164 }
165
166 if (auto fiber = Scheduler::Fiber::current()) {
167 numWaiting++;
168
169 // Currently executing on a scheduler fiber.
170 // Yield to let other tasks run that can unblock this fiber.
171 mutex.lock();
172 auto it = waiting.emplace_front(fiber);
173 mutex.unlock();
174
175 auto res = fiber->wait(lock, timeout, pred);
176
177 mutex.lock();
178 waiting.erase(it);
179 mutex.unlock();
180
181 numWaiting--;
182 return res;
183 }
184
185 // Currently running outside of the scheduler.
186 // Delegate to the std::condition_variable.
187 numWaiting++;
188 numWaitingOnCondition++;
189 auto res = lock.wait_until(condition, timeout, pred);
190 numWaitingOnCondition--;
191 numWaiting--;
192 return res;
193}
194
195} // namespace marl
196
197#endif // marl_condition_variable_h
198