1// Copyright 2019 The Marl Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef marl_scheduler_h
16#define marl_scheduler_h
17
18#include "containers.h"
19#include "debug.h"
20#include "deprecated.h"
21#include "export.h"
22#include "memory.h"
23#include "mutex.h"
24#include "sanitizers.h"
25#include "task.h"
26#include "thread.h"
27
28#include <array>
29#include <atomic>
30#include <chrono>
31#include <condition_variable>
32#include <functional>
33#include <thread>
34
35namespace marl {
36
// Forward declaration of the fiber implementation type owned by
// Scheduler::Fiber (see Fiber::impl). It is not defined in this header.
class OSFiber;
38
// Scheduler asynchronously processes Tasks.
// A scheduler can be bound to one or more threads using the bind() method.
// Once bound to a thread, that thread can call marl::schedule() to enqueue
// work tasks to be executed asynchronously.
// With the default Config (Config::workerThread.count == 0) a Scheduler spawns
// no dedicated worker threads and runs in single-threaded mode. To spawn
// dedicated worker threads, set Config::workerThread.count (for example with
// Config::setWorkerThreadCount()) before constructing the Scheduler.
class Scheduler {
  class Worker;

 public:
  // Time point type used for all fiber timeouts (std::chrono::system_clock).
  using TimePoint = std::chrono::system_clock::time_point;
  // Condition evaluated by the predicated Fiber::wait() overloads.
  using Predicate = std::function<bool()>;
  // Callback run on each dedicated worker thread, given the worker's id.
  // See Config::WorkerThread::initializer.
  using ThreadInitializer = std::function<void(int workerId)>;

  // Config holds scheduler configuration settings that can be passed to the
  // Scheduler constructor.
  struct Config {
    // Default value of fiberStackSize (1024 * 1024).
    static constexpr size_t DefaultFiberStackSize = 1024 * 1024;

    // Per-worker-thread settings.
    struct WorkerThread {
      // Total number of dedicated worker threads to spawn for the scheduler.
      // The default of 0 spawns no dedicated worker threads.
      int count = 0;

      // Initializer function to call after thread creation and before any work
      // is run by the thread.
      ThreadInitializer initializer;

      // Thread affinity policy to use for worker threads.
      std::shared_ptr<Thread::Affinity::Policy> affinityPolicy;
    };

    WorkerThread workerThread;

    // Memory allocator to use for the scheduler and internal allocations.
    Allocator* allocator = Allocator::Default;

    // Size of each fiber stack. This may be rounded up to the nearest
    // allocation granularity for the given platform.
    size_t fiberStackSize = DefaultFiberStackSize;

    // allCores() returns a Config with a worker thread for each of the logical
    // cpus available to the process.
    MARL_EXPORT
    static Config allCores();

    // Fluent setters that return this Config so set calls can be chained.
    MARL_NO_EXPORT inline Config& setAllocator(Allocator*);
    MARL_NO_EXPORT inline Config& setFiberStackSize(size_t);
    MARL_NO_EXPORT inline Config& setWorkerThreadCount(int);
    MARL_NO_EXPORT inline Config& setWorkerThreadInitializer(
        const ThreadInitializer&);
    MARL_NO_EXPORT inline Config& setWorkerThreadAffinityPolicy(
        const std::shared_ptr<Thread::Affinity::Policy>&);
  };

  // Constructor.
  // Builds a scheduler from the given configuration. A copy of the Config is
  // kept and can be retrieved later with config().
  MARL_EXPORT
  Scheduler(const Config&);

  // Destructor.
  // Blocks until the scheduler is unbound from all threads before returning.
  MARL_EXPORT
  ~Scheduler();

  // get() returns the scheduler bound to the current thread.
  // NOTE(review): presumably nullptr when no scheduler is bound - confirm in
  // scheduler.cpp.
  MARL_EXPORT
  static Scheduler* get();

  // bind() binds this scheduler to the current thread.
  // There must be no existing scheduler bound to the thread prior to calling.
  MARL_EXPORT
  void bind();

  // unbind() unbinds the scheduler currently bound to the current thread.
  // There must be an existing scheduler bound to the thread prior to calling.
  // unbind() flushes any enqueued tasks on the single-threaded worker before
  // returning.
  MARL_EXPORT
  static void unbind();

  // enqueue() queues the task for asynchronous execution.
  MARL_EXPORT
  void enqueue(Task&& task);

  // config() returns the Config that was used to build the scheduler.
  MARL_EXPORT
  const Config& config() const;

  // Fibers expose methods to perform cooperative multitasking and are
  // automatically created by the Scheduler.
  //
  // The currently executing Fiber can be obtained by calling Fiber::current().
  //
  // When execution becomes blocked, wait() can be called to suspend execution
  // of the fiber and start executing other pending work. Once the block has
  // been lifted, notify() can be called to reschedule the Fiber on the same
  // thread that previously executed it.
  class Fiber {
   public:
    // current() returns the currently executing fiber, or nullptr if called
    // without a bound scheduler.
    MARL_EXPORT
    static Fiber* current();

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify() and the predicate pred returns true.
    // If the predicate pred does not return true when notify() is called, then
    // the Fiber is automatically re-suspended, and will need to be woken with
    // another call to notify().
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // lock must be locked before calling, and is unlocked by wait() just before
    // the Fiber is suspended, and re-locked before the fiber is resumed. lock
    // will be locked before wait() returns.
    // pred will always be called with the lock held.
    // wait() must only be called on the currently executing fiber.
    MARL_EXPORT
    void wait(marl::lock& lock, const Predicate& pred);

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify() and the predicate pred returns true, or sometime after
    // the timeout is reached.
    // If the predicate pred does not return true when notify() is called, then
    // the Fiber is automatically re-suspended, and will need to be woken with
    // another call to notify() or will be woken sometime after the timeout is
    // reached.
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // lock must be locked before calling, and is unlocked by wait() just before
    // the Fiber is suspended, and re-locked before the fiber is resumed. lock
    // will be locked before wait() returns.
    // pred will always be called with the lock held.
    // wait() must only be called on the currently executing fiber.
    // The timeout is converted with std::chrono::time_point_cast to TimePoint,
    // so it must be expressed on std::chrono::system_clock.
    template <typename Clock, typename Duration>
    MARL_NO_EXPORT inline bool wait(
        marl::lock& lock,
        const std::chrono::time_point<Clock, Duration>& timeout,
        const Predicate& pred);

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify().
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // wait() must only be called on the currently executing fiber.
    //
    // Warning: Unlike wait() overloads that take a lock and predicate, this
    // form of wait() offers no safety for notify() signals that occur before
    // the fiber is suspended, when signalling between different threads. In
    // this scenario you may deadlock. For this reason, it is only ever
    // recommended to use this overload if you can guarantee that the calls to
    // wait() and notify() are made by the same thread.
    //
    // Use with extreme caution.
    MARL_NO_EXPORT inline void wait();

    // wait() suspends execution of this Fiber until the Fiber is woken up with
    // a call to notify(), or sometime after the timeout is reached.
    // While the Fiber is suspended, the scheduler thread may continue executing
    // other tasks.
    // wait() must only be called on the currently executing fiber.
    //
    // Warning: Unlike wait() overloads that take a lock and predicate, this
    // form of wait() offers no safety for notify() signals that occur before
    // the fiber is suspended, when signalling between different threads. For
    // this reason, it is only ever recommended to use this overload if you can
    // guarantee that the calls to wait() and notify() are made by the same
    // thread.
    //
    // Use with extreme caution.
    template <typename Clock, typename Duration>
    MARL_NO_EXPORT inline bool wait(
        const std::chrono::time_point<Clock, Duration>& timeout);

    // notify() reschedules the suspended Fiber for execution.
    // notify() is usually only called when the predicate for one or more wait()
    // calls will likely return true.
    MARL_EXPORT
    void notify();

    // id is the thread-unique identifier of the Fiber.
    uint32_t const id;

   private:
    friend class Allocator;
    friend class Scheduler;

    enum class State {
      // Idle: the Fiber is currently unused, and sits in Worker::idleFibers,
      // ready to be recycled.
      Idle,

      // Yielded: the Fiber is currently blocked on a wait() call with no
      // timeout.
      Yielded,

      // Waiting: the Fiber is currently blocked on a wait() call with a
      // timeout. The fiber is still in the Worker::Work::waiting queue.
      Waiting,

      // Queued: the Fiber is currently queued for execution in the
      // Worker::Work::fibers queue.
      Queued,

      // Running: the Fiber is currently executing.
      Running,
    };

    Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);

    // switchTo() switches execution to the given fiber.
    // switchTo() must only be called on the currently executing fiber.
    void switchTo(Fiber*);

    // create() constructs and returns a new fiber with the given identifier,
    // stack size and func that will be executed when switched to.
    static Allocator::unique_ptr<Fiber> create(
        Allocator* allocator,
        uint32_t id,
        size_t stackSize,
        const std::function<void()>& func);

    // createFromCurrentThread() constructs and returns a new fiber with the
    // given identifier for the current thread.
    static Allocator::unique_ptr<Fiber> createFromCurrentThread(
        Allocator* allocator,
        uint32_t id);

    // toString() returns a string representation of the given State.
    // Used for debugging.
    static const char* toString(State state);

    // Platform fiber implementation backing this Fiber.
    Allocator::unique_ptr<OSFiber> const impl;
    // The Worker that owns (and always resumes) this Fiber.
    Worker* const worker;
    State state = State::Running;  // Guarded by Worker's work.mutex.
  };

 private:
  Scheduler(const Scheduler&) = delete;
  Scheduler(Scheduler&&) = delete;
  Scheduler& operator=(const Scheduler&) = delete;
  Scheduler& operator=(Scheduler&&) = delete;

  // Maximum number of worker threads.
  static constexpr size_t MaxWorkerThreads = 256;

  // WaitingFibers holds all the fibers waiting on a timeout.
  struct WaitingFibers {
    inline WaitingFibers(Allocator*);

    // operator bool() returns true iff there are any waiting fibers.
    inline operator bool() const;

    // take() returns the next fiber that has exceeded its timeout, or nullptr
    // if there are no fibers that have yet exceeded their timeouts.
    inline Fiber* take(const TimePoint& timeout);

    // next() returns the timepoint of the next fiber to timeout.
    // next() can only be called if operator bool() returns true.
    inline TimePoint next() const;

    // add() adds another fiber and timeout to the list of waiting fibers.
    inline void add(const TimePoint& timeout, Fiber* fiber);

    // erase() removes the fiber from the waiting list.
    inline void erase(Fiber* fiber);

    // contains() returns true if fiber is waiting.
    inline bool contains(Fiber* fiber) const;

   private:
    struct Timeout {
      TimePoint timepoint;
      Fiber* fiber;
      inline bool operator<(const Timeout&) const;
    };
    // Ordered by Timeout::operator<, so the earliest deadline can be found.
    containers::set<Timeout, std::less<Timeout>> timeouts;
    // Reverse index: fiber -> its timeout, used by erase() and contains().
    containers::unordered_map<Fiber*, TimePoint> fibers;
  };

  // TODO: Implement a queue that recycles elements to reduce number of
  // heap allocations.
  using TaskQueue = containers::deque<Task>;
  using FiberQueue = containers::deque<Fiber*>;
  using FiberSet = containers::unordered_set<Fiber*>;

  // Workers execute Tasks on a single thread.
  // Once a task is started, it may yield to other tasks on the same Worker.
  // Tasks are always resumed by the same Worker.
  class Worker {
   public:
    enum class Mode {
      // Worker will spawn a background thread to process tasks.
      MultiThreaded,

      // Worker will execute tasks whenever it yields.
      SingleThreaded,
    };

    Worker(Scheduler* scheduler, Mode mode, uint32_t id);

    // start() begins execution of the worker.
    void start() EXCLUDES(work.mutex);

    // stop() ceases execution of the worker, blocking until all pending
    // tasks have fully finished.
    void stop() EXCLUDES(work.mutex);

    // wait() suspends execution of the current task until the predicate pred
    // returns true or the optional timeout is reached.
    // See Fiber::wait() for more information.
    MARL_EXPORT
    bool wait(marl::lock& lock, const TimePoint* timeout, const Predicate& pred)
        EXCLUDES(work.mutex);

    // wait() suspends execution of the current task until the fiber is
    // notified, or the optional timeout is reached.
    // See Fiber::wait() for more information.
    MARL_EXPORT
    bool wait(const TimePoint* timeout) EXCLUDES(work.mutex);

    // suspend() suspends the currently executing Fiber until the fiber is
    // woken with a call to enqueue(Fiber*), or automatically sometime after the
    // optional timeout.
    void suspend(const TimePoint* timeout) REQUIRES(work.mutex);

    // enqueue(Fiber*) enqueues resuming of a suspended fiber.
    void enqueue(Fiber* fiber) EXCLUDES(work.mutex);

    // enqueue(Task&&) enqueues a new, unstarted task.
    void enqueue(Task&& task) EXCLUDES(work.mutex);

    // tryLock() attempts to lock the worker for task enqueuing.
    // If the lock was successful then true is returned, and the caller must
    // call enqueueAndUnlock().
    bool tryLock() EXCLUDES(work.mutex) TRY_ACQUIRE(true, work.mutex);

    // enqueueAndUnlock() enqueues the task and unlocks the worker.
    // Must only be called after a call to tryLock() which returned true.
    // _Releases_lock_(work.mutex)
    void enqueueAndUnlock(Task&& task) REQUIRES(work.mutex) RELEASE(work.mutex);

    // runUntilShutdown() processes all tasks and fibers until there are no more
    // and shutdown is true, at which point runUntilShutdown() returns.
    void runUntilShutdown() REQUIRES(work.mutex);

    // steal() attempts to steal a Task from the worker for another worker.
    // Returns true if a task was taken and assigned to out, otherwise false.
    bool steal(Task& out) EXCLUDES(work.mutex);

    // getCurrent() returns the Worker currently bound to the current
    // thread.
    static inline Worker* getCurrent();

    // getCurrentFiber() returns the Fiber currently being executed.
    inline Fiber* getCurrentFiber() const;

    // Unique identifier of the Worker.
    const uint32_t id;

   private:
    // run() is the task processing function for the worker.
    // run() processes tasks until stop() is called.
    void run() REQUIRES(work.mutex);

    // createWorkerFiber() creates a new fiber that when executed calls
    // run().
    Fiber* createWorkerFiber() REQUIRES(work.mutex);

    // switchToFiber() switches execution to the given fiber. The fiber
    // must belong to this worker.
    void switchToFiber(Fiber*) REQUIRES(work.mutex);

    // runUntilIdle() executes all pending tasks and then returns.
    void runUntilIdle() REQUIRES(work.mutex);

    // waitForWork() blocks until new work is available, potentially calling
    // spinForWork().
    void waitForWork() REQUIRES(work.mutex);

    // spinForWork() attempts to steal work from another Worker, and keeps
    // the thread awake for a short duration. This reduces overheads of
    // frequently putting the thread to sleep and re-waking.
    void spinForWork();

    // enqueueFiberTimeouts() enqueues all the fibers that have finished
    // waiting.
    void enqueueFiberTimeouts() REQUIRES(work.mutex);

    // changeFiberState() transitions fiber from state `from` to state `to`.
    // NOTE(review): presumably also checks that the fiber is currently in
    // `from` - confirm in scheduler.cpp.
    inline void changeFiberState(Fiber* fiber,
                                 Fiber::State from,
                                 Fiber::State to) const REQUIRES(work.mutex);

    // setFiberState() sets the state of fiber to `to`.
    inline void setFiberState(Fiber* fiber, Fiber::State to) const
        REQUIRES(work.mutex);

    // Work holds tasks and fibers that are enqueued on the Worker.
    struct Work {
      inline Work(Allocator*);

      std::atomic<uint64_t> num = {0};  // tasks.size() + fibers.size()
      GUARDED_BY(mutex) uint64_t numBlockedFibers = 0;
      GUARDED_BY(mutex) TaskQueue tasks;
      GUARDED_BY(mutex) FiberQueue fibers;
      GUARDED_BY(mutex) WaitingFibers waiting;
      // NOTE(review): presumably controls whether enqueuing work notifies
      // `added` - confirm in scheduler.cpp.
      GUARDED_BY(mutex) bool notifyAdded = true;
      // Condition variable paired with `mutex`.
      std::condition_variable added;
      marl::mutex mutex;

      // wait() blocks on `added`; must be called with `mutex` held.
      template <typename F>
      inline void wait(F&&) REQUIRES(mutex);
    };

    // Xorshift pseudo-random number generator (shift triple 13, 7, 17),
    // seeded from the system clock. Note that a state of 0 would stay 0.
    // https://en.wikipedia.org/wiki/Xorshift
    class FastRnd {
     public:
      inline uint64_t operator()() {
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        return x;
      }

     private:
      uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
    };

    // The current worker bound to the current thread.
    static thread_local Worker* current;

    Mode const mode;
    Scheduler* const scheduler;
    Allocator::unique_ptr<Fiber> mainFiber;
    Fiber* currentFiber = nullptr;
    Thread thread;
    Work work;
    FiberSet idleFibers;  // Fibers that have completed which can be reused.
    containers::vector<Allocator::unique_ptr<Fiber>, 16>
        workerFibers;  // All fibers created by this worker.
    FastRnd rng;
    bool shutdown = false;  // See runUntilShutdown().
  };

  // stealWork() attempts to steal a task from the worker with the given id.
  // Returns true if a task was stolen and assigned to out, otherwise false.
  bool stealWork(Worker* thief, uint64_t from, Task& out);

  // onBeginSpinning() is called when a Worker calls spinForWork().
  // The scheduler will prioritize this worker for new tasks to try to prevent
  // it going to sleep.
  void onBeginSpinning(int workerId);

  // setBound() sets the scheduler bound to the current thread.
  static void setBound(Scheduler* scheduler);

  // The scheduler currently bound to the current thread.
  static thread_local Scheduler* bound;

  // The immutable configuration used to build the scheduler.
  const Config cfg;

  // NOTE(review): presumably a small ring of worker ids recorded by
  // onBeginSpinning(), indexed via nextSpinningWorkerIdx - confirm in
  // scheduler.cpp.
  std::array<std::atomic<int>, 8> spinningWorkers;
  std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};

  // NOTE(review): presumably the round-robin index used by enqueue() to pick
  // a worker from workerThreads - confirm in scheduler.cpp.
  std::atomic<unsigned int> nextEnqueueIndex = {0};
  // Dedicated worker threads; capacity is MaxWorkerThreads.
  std::array<Worker*, MaxWorkerThreads> workerThreads;

  // Single-threaded workers, keyed by the id of the thread they belong to.
  struct SingleThreadedWorkers {
    inline SingleThreadedWorkers(Allocator*);

    using WorkerByTid =
        containers::unordered_map<std::thread::id,
                                  Allocator::unique_ptr<Worker>>;
    marl::mutex mutex;
    // NOTE(review): presumably signalled when a thread unbinds so that the
    // destructor can wait for all threads to unbind - confirm.
    GUARDED_BY(mutex) std::condition_variable unbind;
    GUARDED_BY(mutex) WorkerByTid byTid;
  };
  SingleThreadedWorkers singleThreadedWorkers;
};
518
519////////////////////////////////////////////////////////////////////////////////
520// Scheduler::Config
521////////////////////////////////////////////////////////////////////////////////
522Scheduler::Config& Scheduler::Config::setAllocator(Allocator* alloc) {
523 allocator = alloc;
524 return *this;
525}
526
527Scheduler::Config& Scheduler::Config::setFiberStackSize(size_t size) {
528 fiberStackSize = size;
529 return *this;
530}
531
532Scheduler::Config& Scheduler::Config::setWorkerThreadCount(int count) {
533 workerThread.count = count;
534 return *this;
535}
536
537Scheduler::Config& Scheduler::Config::setWorkerThreadInitializer(
538 const ThreadInitializer& initializer) {
539 workerThread.initializer = initializer;
540 return *this;
541}
542
543Scheduler::Config& Scheduler::Config::setWorkerThreadAffinityPolicy(
544 const std::shared_ptr<Thread::Affinity::Policy>& policy) {
545 workerThread.affinityPolicy = policy;
546 return *this;
547}
548
549////////////////////////////////////////////////////////////////////////////////
550// Scheduler::Fiber
551////////////////////////////////////////////////////////////////////////////////
552template <typename Clock, typename Duration>
553bool Scheduler::Fiber::wait(
554 marl::lock& lock,
555 const std::chrono::time_point<Clock, Duration>& timeout,
556 const Predicate& pred) {
557 using ToDuration = typename TimePoint::duration;
558 using ToClock = typename TimePoint::clock;
559 auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
560 return worker->wait(lock, &tp, pred);
561}
562
563void Scheduler::Fiber::wait() {
564 worker->wait(nullptr);
565}
566
567template <typename Clock, typename Duration>
568bool Scheduler::Fiber::wait(
569 const std::chrono::time_point<Clock, Duration>& timeout) {
570 using ToDuration = typename TimePoint::duration;
571 using ToClock = typename TimePoint::clock;
572 auto tp = std::chrono::time_point_cast<ToDuration, ToClock>(timeout);
573 return worker->wait(&tp);
574}
575
576Scheduler::Worker* Scheduler::Worker::getCurrent() {
577 MSAN_UNPOISON(&current, sizeof(Worker*));
578 return Worker::current;
579}
580
581Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
582 return currentFiber;
583}
584
585// schedule() schedules the task T to be asynchronously called using the
586// currently bound scheduler.
587inline void schedule(Task&& t) {
588 MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
589 auto scheduler = Scheduler::get();
590 scheduler->enqueue(std::move(t));
591}
592
593// schedule() schedules the function f to be asynchronously called with the
594// given arguments using the currently bound scheduler.
595template <typename Function, typename... Args>
596inline void schedule(Function&& f, Args&&... args) {
597 MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
598 auto scheduler = Scheduler::get();
599 scheduler->enqueue(
600 Task(std::bind(std::forward<Function>(f), std::forward<Args>(args)...)));
601}
602
603// schedule() schedules the function f to be asynchronously called using the
604// currently bound scheduler.
605template <typename Function>
606inline void schedule(Function&& f) {
607 MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
608 auto scheduler = Scheduler::get();
609 scheduler->enqueue(Task(std::forward<Function>(f)));
610}
611
612} // namespace marl
613
614#endif // marl_scheduler_h
615