1 | // Copyright 2019 The Marl Authors. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // https://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #ifndef marl_ticket_h |
16 | #define marl_ticket_h |
17 | |
#include "conditionvariable.h"
#include "pool.h"
#include "scheduler.h"

#include <vector>
21 | |
22 | namespace marl { |
23 | |
24 | // Ticket is a synchronization primitive used to serially order execution. |
25 | // |
26 | // Tickets exist in 3 mutually exclusive states: Waiting, Called and Finished. |
27 | // |
28 | // Tickets are obtained from a Ticket::Queue, using the Ticket::Queue::take() |
29 | // methods. The order in which tickets are taken from the queue dictates the |
30 | // order in which they are called. |
31 | // |
32 | // The first ticket to be taken from a queue will be in the 'called' state, |
33 | // subsequent tickets will be in the 'waiting' state. |
34 | // |
35 | // Ticket::wait() will block until the ticket is called. |
36 | // |
37 | // Ticket::done() moves the ticket into the 'finished' state. If all preceeding |
38 | // tickets are finished, done() will call the next unfinished ticket. |
39 | // |
40 | // If the last remaining reference to an unfinished ticket is dropped then |
41 | // done() will be automatically called on that ticket. |
42 | // |
43 | // Example: |
44 | // |
45 | // void runTasksConcurrentThenSerially(int numConcurrentTasks) |
46 | // { |
47 | // marl::Ticket::Queue queue; |
48 | // for (int i = 0; i < numConcurrentTasks; i++) |
49 | // { |
50 | // auto ticket = queue.take(); |
51 | // marl::schedule([=] { |
52 | // doConcurrentWork(); // <- function may be called concurrently |
53 | // ticket.wait(); // <- serialize tasks |
54 | // doSerialWork(); // <- function will not be called concurrently |
55 | // ticket.done(); // <- optional, as done() is called implicitly on |
56 | // // dropping of last reference |
57 | // }); |
58 | // } |
59 | // } |
60 | class Ticket { |
61 | struct Shared; |
62 | struct Record; |
63 | |
64 | public: |
65 | using OnCall = std::function<void()>; |
66 | |
67 | // Queue hands out Tickets. |
68 | class Queue { |
69 | public: |
70 | // take() returns a single ticket from the queue. |
71 | MARL_NO_EXPORT inline Ticket take(); |
72 | |
73 | // take() retrieves count tickets from the queue, calling f() with each |
74 | // retrieved ticket. |
75 | // F must be a function of the signature: void(Ticket&&) |
76 | template <typename F> |
77 | MARL_NO_EXPORT inline void take(size_t count, const F& f); |
78 | |
79 | private: |
80 | std::shared_ptr<Shared> shared = std::make_shared<Shared>(); |
81 | UnboundedPool<Record> pool; |
82 | }; |
83 | |
84 | MARL_NO_EXPORT inline Ticket() = default; |
85 | MARL_NO_EXPORT inline Ticket(const Ticket& other) = default; |
86 | MARL_NO_EXPORT inline Ticket(Ticket&& other) = default; |
87 | MARL_NO_EXPORT inline Ticket& operator=(const Ticket& other) = default; |
88 | |
89 | // wait() blocks until the ticket is called. |
90 | MARL_NO_EXPORT inline void wait() const; |
91 | |
92 | // done() marks the ticket as finished and calls the next ticket. |
93 | MARL_NO_EXPORT inline void done() const; |
94 | |
95 | // onCall() registers the function f to be invoked when this ticket is |
96 | // called. If the ticket is already called prior to calling onCall(), then |
97 | // f() will be executed immediately. |
98 | // F must be a function of the OnCall signature. |
99 | template <typename F> |
100 | MARL_NO_EXPORT inline void onCall(F&& f) const; |
101 | |
102 | private: |
103 | // Internal doubly-linked-list data structure. One per ticket instance. |
104 | struct Record { |
105 | MARL_NO_EXPORT inline ~Record(); |
106 | |
107 | MARL_NO_EXPORT inline void done(); |
108 | MARL_NO_EXPORT inline void callAndUnlock(marl::lock& lock); |
109 | MARL_NO_EXPORT inline void unlink(); // guarded by shared->mutex |
110 | |
111 | ConditionVariable isCalledCondVar; |
112 | |
113 | std::shared_ptr<Shared> shared; |
114 | Record* next = nullptr; // guarded by shared->mutex |
115 | Record* prev = nullptr; // guarded by shared->mutex |
116 | OnCall onCall; // guarded by shared->mutex |
117 | bool isCalled = false; // guarded by shared->mutex |
118 | std::atomic<bool> isDone = {false}; |
119 | }; |
120 | |
121 | // Data shared between all tickets and the queue. |
122 | struct Shared { |
123 | marl::mutex mutex; |
124 | Record tail; |
125 | }; |
126 | |
127 | MARL_NO_EXPORT inline Ticket(Loan<Record>&& record); |
128 | |
129 | Loan<Record> record; |
130 | }; |
131 | |
132 | //////////////////////////////////////////////////////////////////////////////// |
133 | // Ticket |
134 | //////////////////////////////////////////////////////////////////////////////// |
135 | |
136 | Ticket::Ticket(Loan<Record>&& record) : record(std::move(record)) {} |
137 | |
138 | void Ticket::wait() const { |
139 | marl::lock lock(record->shared->mutex); |
140 | record->isCalledCondVar.wait(lock, [this] { return record->isCalled; }); |
141 | } |
142 | |
// Marks this ticket as finished by forwarding to Record::done(), which also
// calls the next ticket if this one was at the head of the queue. Repeated
// calls are harmless: Record::done() returns early once isDone is set.
void Ticket::done() const {
  record->done();
}
146 | |
147 | template <typename Function> |
148 | void Ticket::onCall(Function&& f) const { |
149 | marl::lock lock(record->shared->mutex); |
150 | if (record->isCalled) { |
151 | marl::schedule(std::forward<Function>(f)); |
152 | return; |
153 | } |
154 | if (record->onCall) { |
155 | struct Joined { |
156 | void operator()() const { |
157 | a(); |
158 | b(); |
159 | } |
160 | OnCall a, b; |
161 | }; |
162 | record->onCall = |
163 | std::move(Joined{std::move(record->onCall), std::forward<Function>(f)}); |
164 | } else { |
165 | record->onCall = std::forward<Function>(f); |
166 | } |
167 | } |
168 | |
169 | //////////////////////////////////////////////////////////////////////////////// |
170 | // Ticket::Queue |
171 | //////////////////////////////////////////////////////////////////////////////// |
172 | |
173 | Ticket Ticket::Queue::take() { |
174 | Ticket out; |
175 | take(1, [&](Ticket&& ticket) { out = std::move(ticket); }); |
176 | return out; |
177 | } |
178 | |
179 | template <typename F> |
180 | void Ticket::Queue::take(size_t n, const F& f) { |
181 | Loan<Record> first, last; |
182 | pool.borrow(n, [&](Loan<Record>&& record) { |
183 | Loan<Record> rec = std::move(record); |
184 | rec->shared = shared; |
185 | if (first.get() == nullptr) { |
186 | first = rec; |
187 | } |
188 | if (last.get() != nullptr) { |
189 | last->next = rec.get(); |
190 | rec->prev = last.get(); |
191 | } |
192 | last = rec; |
193 | f(std::move(Ticket(std::move(rec)))); |
194 | }); |
195 | last->next = &shared->tail; |
196 | marl::lock lock(shared->mutex); |
197 | first->prev = shared->tail.prev; |
198 | shared->tail.prev = last.get(); |
199 | if (first->prev == nullptr) { |
200 | first->callAndUnlock(lock); |
201 | } else { |
202 | first->prev->next = first.get(); |
203 | } |
204 | } |
205 | |
206 | //////////////////////////////////////////////////////////////////////////////// |
207 | // Ticket::Record |
208 | //////////////////////////////////////////////////////////////////////////////// |
209 | |
210 | Ticket::Record::~Record() { |
211 | if (shared != nullptr) { |
212 | done(); |
213 | } |
214 | } |
215 | |
216 | void Ticket::Record::done() { |
217 | if (isDone.exchange(true)) { |
218 | return; |
219 | } |
220 | marl::lock lock(shared->mutex); |
221 | auto callNext = (prev == nullptr && next != nullptr) ? next : nullptr; |
222 | unlink(); |
223 | if (callNext != nullptr) { |
224 | // lock needs to be held otherwise callNext might be destructed. |
225 | callNext->callAndUnlock(lock); |
226 | } |
227 | } |
228 | |
229 | void Ticket::Record::callAndUnlock(marl::lock& lock) { |
230 | if (isCalled) { |
231 | return; |
232 | } |
233 | isCalled = true; |
234 | OnCall callback; |
235 | std::swap(callback, onCall); |
236 | isCalledCondVar.notify_all(); |
237 | lock.unlock_no_tsa(); |
238 | |
239 | if (callback) { |
240 | marl::schedule(std::move(callback)); |
241 | } |
242 | } |
243 | |
244 | void Ticket::Record::unlink() { |
245 | if (prev != nullptr) { |
246 | prev->next = next; |
247 | } |
248 | if (next != nullptr) { |
249 | next->prev = prev; |
250 | } |
251 | prev = nullptr; |
252 | next = nullptr; |
253 | } |
254 | |
255 | } // namespace marl |
256 | |
257 | #endif // marl_ticket_h |
258 | |