1// Copyright 2019 The Marl Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef marl_ticket_h
16#define marl_ticket_h
17
#include "conditionvariable.h"
#include "pool.h"
#include "scheduler.h"

#include <vector>
21
22namespace marl {
23
24// Ticket is a synchronization primitive used to serially order execution.
25//
26// Tickets exist in 3 mutually exclusive states: Waiting, Called and Finished.
27//
28// Tickets are obtained from a Ticket::Queue, using the Ticket::Queue::take()
29// methods. The order in which tickets are taken from the queue dictates the
30// order in which they are called.
31//
32// The first ticket to be taken from a queue will be in the 'called' state,
33// subsequent tickets will be in the 'waiting' state.
34//
35// Ticket::wait() will block until the ticket is called.
36//
37// Ticket::done() moves the ticket into the 'finished' state. If all preceeding
38// tickets are finished, done() will call the next unfinished ticket.
39//
40// If the last remaining reference to an unfinished ticket is dropped then
41// done() will be automatically called on that ticket.
42//
43// Example:
44//
45// void runTasksConcurrentThenSerially(int numConcurrentTasks)
46// {
47// marl::Ticket::Queue queue;
48// for (int i = 0; i < numConcurrentTasks; i++)
49// {
50// auto ticket = queue.take();
51// marl::schedule([=] {
52// doConcurrentWork(); // <- function may be called concurrently
53// ticket.wait(); // <- serialize tasks
54// doSerialWork(); // <- function will not be called concurrently
55// ticket.done(); // <- optional, as done() is called implicitly on
56// // dropping of last reference
57// });
58// }
59// }
60class Ticket {
61 struct Shared;
62 struct Record;
63
64 public:
65 using OnCall = std::function<void()>;
66
67 // Queue hands out Tickets.
68 class Queue {
69 public:
70 // take() returns a single ticket from the queue.
71 MARL_NO_EXPORT inline Ticket take();
72
73 // take() retrieves count tickets from the queue, calling f() with each
74 // retrieved ticket.
75 // F must be a function of the signature: void(Ticket&&)
76 template <typename F>
77 MARL_NO_EXPORT inline void take(size_t count, const F& f);
78
79 private:
80 std::shared_ptr<Shared> shared = std::make_shared<Shared>();
81 UnboundedPool<Record> pool;
82 };
83
84 MARL_NO_EXPORT inline Ticket() = default;
85 MARL_NO_EXPORT inline Ticket(const Ticket& other) = default;
86 MARL_NO_EXPORT inline Ticket(Ticket&& other) = default;
87 MARL_NO_EXPORT inline Ticket& operator=(const Ticket& other) = default;
88
89 // wait() blocks until the ticket is called.
90 MARL_NO_EXPORT inline void wait() const;
91
92 // done() marks the ticket as finished and calls the next ticket.
93 MARL_NO_EXPORT inline void done() const;
94
95 // onCall() registers the function f to be invoked when this ticket is
96 // called. If the ticket is already called prior to calling onCall(), then
97 // f() will be executed immediately.
98 // F must be a function of the OnCall signature.
99 template <typename F>
100 MARL_NO_EXPORT inline void onCall(F&& f) const;
101
102 private:
103 // Internal doubly-linked-list data structure. One per ticket instance.
104 struct Record {
105 MARL_NO_EXPORT inline ~Record();
106
107 MARL_NO_EXPORT inline void done();
108 MARL_NO_EXPORT inline void callAndUnlock(marl::lock& lock);
109 MARL_NO_EXPORT inline void unlink(); // guarded by shared->mutex
110
111 ConditionVariable isCalledCondVar;
112
113 std::shared_ptr<Shared> shared;
114 Record* next = nullptr; // guarded by shared->mutex
115 Record* prev = nullptr; // guarded by shared->mutex
116 OnCall onCall; // guarded by shared->mutex
117 bool isCalled = false; // guarded by shared->mutex
118 std::atomic<bool> isDone = {false};
119 };
120
121 // Data shared between all tickets and the queue.
122 struct Shared {
123 marl::mutex mutex;
124 Record tail;
125 };
126
127 MARL_NO_EXPORT inline Ticket(Loan<Record>&& record);
128
129 Loan<Record> record;
130};
131
132////////////////////////////////////////////////////////////////////////////////
133// Ticket
134////////////////////////////////////////////////////////////////////////////////
135
136Ticket::Ticket(Loan<Record>&& record) : record(std::move(record)) {}
137
138void Ticket::wait() const {
139 marl::lock lock(record->shared->mutex);
140 record->isCalledCondVar.wait(lock, [this] { return record->isCalled; });
141}
142
143void Ticket::done() const {
144 record->done();
145}
146
147template <typename Function>
148void Ticket::onCall(Function&& f) const {
149 marl::lock lock(record->shared->mutex);
150 if (record->isCalled) {
151 marl::schedule(std::forward<Function>(f));
152 return;
153 }
154 if (record->onCall) {
155 struct Joined {
156 void operator()() const {
157 a();
158 b();
159 }
160 OnCall a, b;
161 };
162 record->onCall =
163 std::move(Joined{std::move(record->onCall), std::forward<Function>(f)});
164 } else {
165 record->onCall = std::forward<Function>(f);
166 }
167}
168
169////////////////////////////////////////////////////////////////////////////////
170// Ticket::Queue
171////////////////////////////////////////////////////////////////////////////////
172
173Ticket Ticket::Queue::take() {
174 Ticket out;
175 take(1, [&](Ticket&& ticket) { out = std::move(ticket); });
176 return out;
177}
178
179template <typename F>
180void Ticket::Queue::take(size_t n, const F& f) {
181 Loan<Record> first, last;
182 pool.borrow(n, [&](Loan<Record>&& record) {
183 Loan<Record> rec = std::move(record);
184 rec->shared = shared;
185 if (first.get() == nullptr) {
186 first = rec;
187 }
188 if (last.get() != nullptr) {
189 last->next = rec.get();
190 rec->prev = last.get();
191 }
192 last = rec;
193 f(std::move(Ticket(std::move(rec))));
194 });
195 last->next = &shared->tail;
196 marl::lock lock(shared->mutex);
197 first->prev = shared->tail.prev;
198 shared->tail.prev = last.get();
199 if (first->prev == nullptr) {
200 first->callAndUnlock(lock);
201 } else {
202 first->prev->next = first.get();
203 }
204}
205
206////////////////////////////////////////////////////////////////////////////////
207// Ticket::Record
208////////////////////////////////////////////////////////////////////////////////
209
210Ticket::Record::~Record() {
211 if (shared != nullptr) {
212 done();
213 }
214}
215
216void Ticket::Record::done() {
217 if (isDone.exchange(true)) {
218 return;
219 }
220 marl::lock lock(shared->mutex);
221 auto callNext = (prev == nullptr && next != nullptr) ? next : nullptr;
222 unlink();
223 if (callNext != nullptr) {
224 // lock needs to be held otherwise callNext might be destructed.
225 callNext->callAndUnlock(lock);
226 }
227}
228
229void Ticket::Record::callAndUnlock(marl::lock& lock) {
230 if (isCalled) {
231 return;
232 }
233 isCalled = true;
234 OnCall callback;
235 std::swap(callback, onCall);
236 isCalledCondVar.notify_all();
237 lock.unlock_no_tsa();
238
239 if (callback) {
240 marl::schedule(std::move(callback));
241 }
242}
243
244void Ticket::Record::unlink() {
245 if (prev != nullptr) {
246 prev->next = next;
247 }
248 if (next != nullptr) {
249 next->prev = prev;
250 }
251 prev = nullptr;
252 next = nullptr;
253}
254
255} // namespace marl
256
257#endif // marl_ticket_h
258