1 | // Copyright 2019 The Marl Authors. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #ifndef marl_waitgroup_h |
16 | #define marl_waitgroup_h |
17 | |
18 | #include "conditionvariable.h" |
19 | #include "debug.h" |
20 | |
21 | #include <atomic> |
22 | #include <mutex> |
23 | |
24 | namespace marl { |
25 | |
// WaitGroup is a synchronization primitive that holds an internal counter that
// can be incremented, decremented and waited on until it reaches 0.
// WaitGroups can be used as a simple mechanism for waiting on a number of
// concurrently executing tasks to complete.
30 | // |
31 | // Example: |
32 | // |
33 | // void runTasksConcurrently(int numConcurrentTasks) |
34 | // { |
35 | // // Construct the WaitGroup with an initial count of numConcurrentTasks. |
36 | // marl::WaitGroup wg(numConcurrentTasks); |
37 | // for (int i = 0; i < numConcurrentTasks; i++) |
38 | // { |
39 | // // Schedule a task to be run asynchronously. |
40 | // // These may all be run concurrently. |
41 | // marl::schedule([=] { |
42 | // // Once the task has finished, decrement the waitgroup counter |
43 | // // to signal that this has completed. |
44 | // defer(wg.done()); |
45 | // doSomeWork(); |
46 | // }); |
47 | // } |
48 | // // Block until all tasks have completed. |
49 | // wg.wait(); |
50 | // } |
51 | class WaitGroup { |
52 | public: |
53 | // Constructs the WaitGroup with the specified initial count. |
54 | MARL_NO_EXPORT inline WaitGroup(unsigned int initialCount = 0, |
55 | Allocator* allocator = Allocator::Default); |
56 | |
57 | // add() increments the internal counter by count. |
58 | MARL_NO_EXPORT inline void add(unsigned int count = 1) const; |
59 | |
60 | // done() decrements the internal counter by one. |
61 | // Returns true if the internal count has reached zero. |
62 | MARL_NO_EXPORT inline bool done() const; |
63 | |
64 | // wait() blocks until the WaitGroup counter reaches zero. |
65 | MARL_NO_EXPORT inline void wait() const; |
66 | |
67 | private: |
68 | struct Data { |
69 | MARL_NO_EXPORT inline Data(Allocator* allocator); |
70 | |
71 | std::atomic<unsigned int> count = {0}; |
72 | ConditionVariable cv; |
73 | marl::mutex mutex; |
74 | }; |
75 | const std::shared_ptr<Data> data; |
76 | }; |
77 | |
78 | WaitGroup::Data::Data(Allocator* allocator) : cv(allocator) {} |
79 | |
80 | WaitGroup::WaitGroup(unsigned int initialCount /* = 0 */, |
81 | Allocator* allocator /* = Allocator::Default */) |
82 | : data(std::make_shared<Data>(allocator)) { |
83 | data->count = initialCount; |
84 | } |
85 | |
86 | void WaitGroup::add(unsigned int count /* = 1 */) const { |
87 | data->count += count; |
88 | } |
89 | |
90 | bool WaitGroup::done() const { |
91 | MARL_ASSERT(data->count > 0, "marl::WaitGroup::done() called too many times" ); |
92 | auto count = --data->count; |
93 | if (count == 0) { |
94 | marl::lock lock(data->mutex); |
95 | data->cv.notify_all(); |
96 | return true; |
97 | } |
98 | return false; |
99 | } |
100 | |
101 | void WaitGroup::wait() const { |
102 | marl::lock lock(data->mutex); |
103 | data->cv.wait(lock, [this] { return data->count == 0; }); |
104 | } |
105 | |
106 | } // namespace marl |
107 | |
108 | #endif // marl_waitgroup_h |
109 | |