1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#ifndef GLOW_SUPPORT_THREADPOOL_H
17#define GLOW_SUPPORT_THREADPOOL_H
18
#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
28
29namespace glow {
30
namespace threads {
/// Look up the identifier that has been assigned to the calling thread.
/// Each thread observes its own, unique value.
auto getThreadId() -> size_t;

/// Allocate a brand-new unique identifier for a thread that has no OS thread
/// behind it (a "virtual" thread, e.g. the tid used for a device).
auto createThreadId() -> size_t;
} // namespace threads
39
#ifdef WIN32
/// A copyable wrapper for a callable (typically a lambda) whose captures may
/// be non-copyable.
///
/// This is needed for Visual Studio builds, where std::packaged_task wraps a
/// std::function and therefore requires its target to be copyable. The
/// callable is stored behind a std::shared_ptr, so every copy of a
/// shared_function refers to the SAME underlying callable object: state that
/// one invocation mutates is visible through all copies.
template <class F> struct shared_function {
  /// The shared, heap-allocated callable.
  std::shared_ptr<F> f;

  shared_function() = delete;
  /// Take ownership of \p f_ by moving it into shared storage.
  shared_function(F &&f_) : f(std::make_shared<F>(std::move(f_))) {}
  /// Copy \p f_ into shared storage. Being a non-template member of a class
  /// template, its body is only instantiated when used, so move-only
  /// callables remain usable through the rvalue overload above.
  shared_function(const F &f_) : f(std::make_shared<F>(f_)) {}
  shared_function(shared_function const &) = default;
  shared_function(shared_function &&) = default;
  shared_function &operator=(shared_function const &) = default;
  shared_function &operator=(shared_function &&) = default;

  /// Invoke the shared callable with \p as forwarded. Although this operator
  /// is const, the target is reached through a pointer, so it is invoked as a
  /// non-const object (e.g. a `mutable` lambda works).
  template <class... As> auto operator()(As &&...as) const {
    return (*f)(std::forward<As>(as)...);
  }
};

/// Wrap \p f in a shared_function. Rvalue arguments are moved into the shared
/// storage; lvalue arguments are copied (previously lvalues failed to compile
/// because only an rvalue-reference constructor existed).
template <class F>
shared_function<std::decay_t<F>> make_shared_function(F &&f) {
  return {std::forward<F>(f)};
}
#endif
62
/// An executor that runs Tasks on a single thread.
///
/// Work items are pushed onto workQueue_ (guarded by workQueueMtx_) and the
/// worker thread is woken through queueNotEmpty_. The class owns a std::thread
/// and a std::mutex and is therefore non-copyable.
class ThreadExecutor final {
public:
  /// Constructor. Initializes one thread backed by the workQueue_.
  /// NOTE(review): \p name is presumably used to label the worker thread —
  /// confirm against the implementation.
  explicit ThreadExecutor(const std::string &name = "");

  /// Destructor. Signals the thread to stop and waits for exit.
  ~ThreadExecutor();

  ThreadExecutor(const ThreadExecutor &) = delete;
  ThreadExecutor &operator=(const ThreadExecutor &) = delete;

  /// Submit \p fn as a work item for this executor's thread.
  /// \p fn must be invocable with no arguments; any value it returns is
  /// discarded. An rvalue \p fn is moved into the task; an lvalue \p fn is
  /// copied, so the caller's object is left intact and may be submitted again.
  /// \returns a future that becomes ready when the task is executed.
  template <typename F> std::future<void> submit(F &&fn) {
    // Build the task before taking the lock so allocation and the callable's
    // copy/move do not lengthen the critical section.
#ifdef WIN32
    // make_shared_function takes its argument by rvalue; materialize an owned
    // copy/move first so both lvalue and rvalue callers are supported.
    std::packaged_task<void(void)> task(
        make_shared_function(std::decay_t<F>(std::forward<F>(fn))));
#else
    std::packaged_task<void(void)> task(std::forward<F>(fn));
#endif

    auto future = task.get_future();
    {
      std::lock_guard<std::mutex> lock(workQueueMtx_);
      workQueue_.push(std::move(task));
    }
    // Notify after releasing the lock so the woken worker does not
    // immediately block on workQueueMtx_.
    queueNotEmpty_.notify_one();
    return future;
  }

  /// Submit \p task as a work item for the thread pool.
  std::future<void> submit(std::packaged_task<void(void)> &&task);

  /// Signal the worker thread to stop.
  /// NOTE(review): \p block presumably controls whether this waits for the
  /// worker to exit — confirm against the implementation.
  void stop(bool block = false);

private:
  /// Main loop run by the worker thread.
  void threadPoolWorkerMain();

  /// Flag checked in between work items to determine whether we should stop and
  /// exit.
  std::atomic<bool> shouldStop_{false};

  /// Queue of work items.
  std::queue<std::packaged_task<void(void)>> workQueue_;

  /// Mutex to coordinate access to the work queue.
  std::mutex workQueueMtx_;

  /// Condition variable to signal to the worker when work is added to the
  /// work queue.
  std::condition_variable queueNotEmpty_;

  /// Worker thread.
  std::thread worker_;
};
117
118/// Thread pool for asynchronous execution of generic functions.
119class ThreadPool final {
120public:
121 /// Constructor. Initializes a thread pool with \p numWorkers
122 /// threads and has them all run ThreadPool::threadPoolWorkerMain.
123 ThreadPool(unsigned numWorkers = kNumWorkers, const std::string &name = "");
124
125 /// Destructor. Signals to all threads to stop and waits for all of them
126 /// to exit.
127 ~ThreadPool();
128
129 /// Stop all threads and optionally wait for them to join.
130 void stop(bool block = false);
131
132 /// Submit \p fn as a work item for the thread pool.
133 /// \p fn must be a lambda with void return type and arguments.
134 template <typename F> std::future<void> submit(F &&fn) {
135#ifdef WIN32
136 std::packaged_task<void(void)> task(make_shared_function(std::move(fn)));
137#else
138 std::packaged_task<void(void)> task(std::move(fn));
139#endif
140
141 return submit(std::move(task));
142 }
143
144 /// Submit \p task as a work item for the thread pool.
145 std::future<void> submit(std::packaged_task<void(void)> &&task);
146
147 /// Returns a ThreadExecutor that can be accessed directly, allowing
148 /// submitting multiple tasks to the same thread.
149 ThreadExecutor *getExecutor() {
150 size_t exIndex = nextWorker_++;
151 return workers_[exIndex % workers_.size()];
152 }
153
154 /// Run the provided function on every thread in the ThreadPool. The function
155 /// must be copyable.
156 template <typename F> std::future<void> runOnAllThreads(F &&fn) {
157 std::shared_ptr<std::atomic<size_t>> finished =
158 std::make_shared<std::atomic<size_t>>(0);
159 std::shared_ptr<std::promise<void>> promise =
160 std::make_shared<std::promise<void>>();
161 for (auto *w : workers_) {
162 w->submit([fn, finished, promise, total = workers_.size()]() {
163 fn();
164 if ((finished->fetch_add(1) + 1) >= total) {
165 promise->set_value();
166 }
167 });
168 }
169
170 return promise->get_future();
171 }
172
173 const std::set<size_t> &getThreadIds() { return threadIds_; }
174
175private:
176 /// The default number of workers in the thread pool (overridable).
177 constexpr static unsigned kNumWorkers = 10;
178
179 /// Vector of worker thread objects.
180 /// It is safe to access this without a lock as it is const after
181 /// construction.
182 std::vector<ThreadExecutor *> workers_;
183
184 /// Round robin index for the next work thread.
185 std::atomic<size_t> nextWorker_{0};
186
187 /// Thread Ids and associated names owned by this ThreadPool.
188 std::set<size_t> threadIds_;
189};
190} // namespace glow
191
192#endif // GLOW_SUPPORT_THREADPOOL_H
193