1// Copyright 2019 The Marl Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef marl_waitgroup_h
16#define marl_waitgroup_h
17
#include "conditionvariable.h"
#include "debug.h"

#include <atomic>
#include <memory>
#include <mutex>
23
24namespace marl {
25
// WaitGroup is a synchronization primitive that holds an internal counter that
// can be incremented, decremented and waited on until it reaches 0.
// WaitGroups can be used as a simple mechanism for waiting on a number of
// concurrently executing tasks to complete.
30//
31// Example:
32//
//  void runTasksConcurrently(int numConcurrentTasks)
//  {
//      // Construct the WaitGroup with an initial count of numConcurrentTasks.
//      marl::WaitGroup wg(numConcurrentTasks);
//      for (int i = 0; i < numConcurrentTasks; i++)
//      {
//          // Schedule a task to be run asynchronously.
//          // These may all be run concurrently.
//          marl::schedule([=] {
//              // Once the task has finished, decrement the waitgroup counter
//              // to signal that this has completed.
//              defer(wg.done());
//              doSomeWork();
//          });
//      }
//      // Block until all tasks have completed.
//      wg.wait();
//  }
51class WaitGroup {
52 public:
53 // Constructs the WaitGroup with the specified initial count.
54 MARL_NO_EXPORT inline WaitGroup(unsigned int initialCount = 0,
55 Allocator* allocator = Allocator::Default);
56
57 // add() increments the internal counter by count.
58 MARL_NO_EXPORT inline void add(unsigned int count = 1) const;
59
60 // done() decrements the internal counter by one.
61 // Returns true if the internal count has reached zero.
62 MARL_NO_EXPORT inline bool done() const;
63
64 // wait() blocks until the WaitGroup counter reaches zero.
65 MARL_NO_EXPORT inline void wait() const;
66
67 private:
68 struct Data {
69 MARL_NO_EXPORT inline Data(Allocator* allocator);
70
71 std::atomic<unsigned int> count = {0};
72 ConditionVariable cv;
73 marl::mutex mutex;
74 };
75 const std::shared_ptr<Data> data;
76};
77
78WaitGroup::Data::Data(Allocator* allocator) : cv(allocator) {}
79
80WaitGroup::WaitGroup(unsigned int initialCount /* = 0 */,
81 Allocator* allocator /* = Allocator::Default */)
82 : data(std::make_shared<Data>(allocator)) {
83 data->count = initialCount;
84}
85
86void WaitGroup::add(unsigned int count /* = 1 */) const {
87 data->count += count;
88}
89
90bool WaitGroup::done() const {
91 MARL_ASSERT(data->count > 0, "marl::WaitGroup::done() called too many times");
92 auto count = --data->count;
93 if (count == 0) {
94 marl::lock lock(data->mutex);
95 data->cv.notify_all();
96 return true;
97 }
98 return false;
99}
100
101void WaitGroup::wait() const {
102 marl::lock lock(data->mutex);
103 data->cv.wait(lock, [this] { return data->count == 0; });
104}
105
106} // namespace marl
107
108#endif // marl_waitgroup_h
109