1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #ifndef GLOW_SUPPORT_THREADPOOL_H |
17 | #define GLOW_SUPPORT_THREADPOOL_H |
18 | |
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
28 | |
29 | namespace glow { |
30 | |
namespace threads {
/// \returns a unique id associated with the current thread.
/// Only the declaration lives in this header; the definition is out of line.
size_t getThreadId();

/// \returns a unique id associated with a new virtual thread (i.e. a device
/// tid).
/// NOTE(review): presumably every call hands out a fresh id that is distinct
/// from the ids returned by getThreadId() -- confirm against the definition.
size_t createThreadId();
} // namespace threads
39 | |
40 | #ifdef WIN32 |
/// A copyable wrapper for a callable that has non-copyable objects in its
/// lambda capture.
/// This is useful for VS builds where std::packaged_task wraps a
/// std::function which must be copyable.
///
/// The callable is stored once behind a std::shared_ptr; copying the wrapper
/// only copies the pointer, so all copies invoke the same underlying object.
/// A moved-from wrapper holds a null pointer and must not be invoked.
template <class F> struct shared_function {
  /// Shared handle to the wrapped callable.
  std::shared_ptr<F> f;

  shared_function() = delete;

  /// Wraps a copy of \p f_. Only instantiated when an lvalue callable is
  /// passed, so move-only callables remain usable through the overload below.
  shared_function(const F &f_) : f(std::make_shared<F>(f_)) {}

  /// Wraps \p f_ by moving it into shared storage.
  shared_function(F &&f_) : f(std::make_shared<F>(std::move(f_))) {}

  shared_function(shared_function const &) = default;
  shared_function(shared_function &&) = default;
  shared_function &operator=(shared_function const &) = default;
  shared_function &operator=(shared_function &&) = default;

  /// Invokes the wrapped callable with \p as perfectly forwarded.
  /// \returns exactly what the callable returns; decltype(auto) keeps
  /// reference return types intact instead of decaying them to values.
  template <class... As> decltype(auto) operator()(As &&...as) const {
    return (*f)(std::forward<As>(as)...);
  }
};

/// Builds a shared_function around \p f, deducing the stored (decayed) type.
/// Lvalue arguments are copied into the wrapper, rvalue arguments are moved.
template <class F>
shared_function<std::decay_t<F>> make_shared_function(F &&f) {
  return {std::forward<F>(f)};
}
61 | #endif |
62 | |
/// An executor that runs Tasks on a single thread.
class ThreadExecutor final {
public:
  /// Constructor. Initializes one thread backed by the workQueue_.
  explicit ThreadExecutor(const std::string &name = "");

  /// Destructor. Signals the thread to stop and waits for exit.
  ~ThreadExecutor();

  /// Submit \p fn as a work item for this executor.
  /// \p fn must be callable with no arguments and return void. An lvalue
  /// \p fn is copied and an rvalue \p fn is moved, so pass std::move(fn) for
  /// move-only callables.
  /// \returns the future of the std::packaged_task that wraps \p fn.
  template <typename F> std::future<void> submit(F &&fn) {
    // Build the task and fetch its future before taking the lock so that the
    // queue mutex is held only for the push itself.
#ifdef WIN32
    // make_shared_function is handed a freshly materialised decayed copy so it
    // always receives an rvalue, regardless of how fn was passed in.
    std::packaged_task<void(void)> task(
        make_shared_function(std::decay_t<F>(std::forward<F>(fn))));
#else
    std::packaged_task<void(void)> task(std::forward<F>(fn));
#endif
    auto future = task.get_future();

    {
      std::lock_guard<std::mutex> lock(workQueueMtx_);
      workQueue_.push(std::move(task));
    }
    // Notify after releasing the mutex.
    queueNotEmpty_.notify_one();
    return future;
  }

  /// Submit \p task as a work item for this executor.
  std::future<void> submit(std::packaged_task<void(void)> &&task);

  /// Request that the executor stops.
  /// NOTE(review): presumably \p block selects whether the call waits for the
  /// worker thread to exit, mirroring ThreadPool::stop -- confirm against the
  /// out-of-line definition.
  void stop(bool block = false);

protected:
  /// Main loop run by the worker thread.
  void threadPoolWorkerMain();

  /// Flag checked in between work items to determine whether we should stop and
  /// exit.
  std::atomic<bool> shouldStop_{false};

  /// Queue of work items.
  std::queue<std::packaged_task<void(void)>> workQueue_;

  /// Mutex to coordinate access to the work queue.
  std::mutex workQueueMtx_;

  /// Condition variable to signal to threads when work is added to
  /// the work queue.
  std::condition_variable queueNotEmpty_;

  /// Worker thread.
  std::thread worker_;
};
117 | |
118 | /// Thread pool for asynchronous execution of generic functions. |
119 | class ThreadPool final { |
120 | public: |
121 | /// Constructor. Initializes a thread pool with \p numWorkers |
122 | /// threads and has them all run ThreadPool::threadPoolWorkerMain. |
123 | ThreadPool(unsigned numWorkers = kNumWorkers, const std::string &name = "" ); |
124 | |
125 | /// Destructor. Signals to all threads to stop and waits for all of them |
126 | /// to exit. |
127 | ~ThreadPool(); |
128 | |
129 | /// Stop all threads and optionally wait for them to join. |
130 | void stop(bool block = false); |
131 | |
132 | /// Submit \p fn as a work item for the thread pool. |
133 | /// \p fn must be a lambda with void return type and arguments. |
134 | template <typename F> std::future<void> submit(F &&fn) { |
135 | #ifdef WIN32 |
136 | std::packaged_task<void(void)> task(make_shared_function(std::move(fn))); |
137 | #else |
138 | std::packaged_task<void(void)> task(std::move(fn)); |
139 | #endif |
140 | |
141 | return submit(std::move(task)); |
142 | } |
143 | |
144 | /// Submit \p task as a work item for the thread pool. |
145 | std::future<void> submit(std::packaged_task<void(void)> &&task); |
146 | |
147 | /// Returns a ThreadExecutor that can be accessed directly, allowing |
148 | /// submitting multiple tasks to the same thread. |
149 | ThreadExecutor *getExecutor() { |
150 | size_t exIndex = nextWorker_++; |
151 | return workers_[exIndex % workers_.size()]; |
152 | } |
153 | |
154 | /// Run the provided function on every thread in the ThreadPool. The function |
155 | /// must be copyable. |
156 | template <typename F> std::future<void> runOnAllThreads(F &&fn) { |
157 | std::shared_ptr<std::atomic<size_t>> finished = |
158 | std::make_shared<std::atomic<size_t>>(0); |
159 | std::shared_ptr<std::promise<void>> promise = |
160 | std::make_shared<std::promise<void>>(); |
161 | for (auto *w : workers_) { |
162 | w->submit([fn, finished, promise, total = workers_.size()]() { |
163 | fn(); |
164 | if ((finished->fetch_add(1) + 1) >= total) { |
165 | promise->set_value(); |
166 | } |
167 | }); |
168 | } |
169 | |
170 | return promise->get_future(); |
171 | } |
172 | |
173 | const std::set<size_t> &getThreadIds() { return threadIds_; } |
174 | |
175 | private: |
176 | /// The default number of workers in the thread pool (overridable). |
177 | constexpr static unsigned kNumWorkers = 10; |
178 | |
179 | /// Vector of worker thread objects. |
180 | /// It is safe to access this without a lock as it is const after |
181 | /// construction. |
182 | std::vector<ThreadExecutor *> workers_; |
183 | |
184 | /// Round robin index for the next work thread. |
185 | std::atomic<size_t> nextWorker_{0}; |
186 | |
187 | /// Thread Ids and associated names owned by this ThreadPool. |
188 | std::set<size_t> threadIds_; |
189 | }; |
190 | } // namespace glow |
191 | |
192 | #endif // GLOW_SUPPORT_THREADPOOL_H |
193 | |