1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "ruy/thread_pool.h" |
17 | |
18 | #include <atomic> |
19 | #include <chrono> // NOLINT(build/c++11) |
20 | #include <condition_variable> // NOLINT(build/c++11) |
21 | #include <cstdint> |
22 | #include <cstdlib> |
23 | #include <memory> |
24 | #include <mutex> // NOLINT(build/c++11) |
25 | #include <thread> // NOLINT(build/c++11) |
26 | |
27 | #include "ruy/check_macros.h" |
28 | #include "ruy/denormal.h" |
29 | #include "ruy/trace.h" |
30 | #include "ruy/wait.h" |
31 | |
32 | namespace ruy { |
33 | |
34 | // A worker thread. |
35 | class Thread { |
36 | public: |
37 | explicit Thread(BlockingCounter* count_busy_threads, Duration spin_duration) |
38 | : state_(State::Startup), |
39 | count_busy_threads_(count_busy_threads), |
40 | spin_duration_(spin_duration) { |
41 | thread_.reset(new std::thread(ThreadFunc, this)); |
42 | } |
43 | |
44 | void RequestExitAsSoonAsPossible() { |
45 | ChangeStateFromOutsideThread(State::ExitAsSoonAsPossible); |
46 | } |
47 | |
48 | ~Thread() { |
49 | RUY_DCHECK_EQ(state_.load(), State::ExitAsSoonAsPossible); |
50 | thread_->join(); |
51 | } |
52 | |
53 | // Called by an outside thead to give work to the worker thread. |
54 | void StartWork(Task* task) { |
55 | ChangeStateFromOutsideThread(State::HasWork, task); |
56 | } |
57 | |
58 | private: |
59 | enum class State { |
60 | Startup, // The initial state before the thread loop runs. |
61 | Ready, // Is not working, has not yet received new work to do. |
62 | HasWork, // Has work to do. |
63 | ExitAsSoonAsPossible // Should exit at earliest convenience. |
64 | }; |
65 | |
66 | // Implements the state_ change to State::Ready, which is where we consume |
67 | // task_. Only called on the worker thread. |
68 | // Reads task_, so assumes ordering past any prior writes to task_. |
69 | void RevertToReadyState() { |
70 | RUY_TRACE_SCOPE_NAME("Worker thread task" ); |
71 | // See task_ member comment for the ordering of accesses. |
72 | if (task_) { |
73 | task_->Run(); |
74 | task_ = nullptr; |
75 | } |
76 | // No need to notify state_cond_, since only the worker thread ever waits |
77 | // on it, and we are that thread. |
78 | // Relaxed order because ordering is already provided by the |
79 | // count_busy_threads_->DecrementCount() at the next line, since the next |
80 | // state_ mutation will be to give new work and that won't happen before |
81 | // the outside thread has finished the current batch with a |
82 | // count_busy_threads_->Wait(). |
83 | state_.store(State::Ready, std::memory_order_relaxed); |
84 | count_busy_threads_->DecrementCount(); |
85 | } |
86 | |
87 | // Changes State, from outside thread. |
88 | // |
89 | // The Task argument is to be used only with new_state==HasWork. |
90 | // It specifies the Task being handed to this Thread. |
91 | // |
92 | // new_task is only used with State::HasWork. |
93 | void ChangeStateFromOutsideThread(State new_state, Task* new_task = nullptr) { |
94 | RUY_DCHECK(new_state == State::ExitAsSoonAsPossible || |
95 | new_state == State::HasWork); |
96 | RUY_DCHECK((new_task != nullptr) == (new_state == State::HasWork)); |
97 | |
98 | #ifndef NDEBUG |
99 | // Debug-only sanity checks based on old_state. |
100 | State old_state = state_.load(); |
101 | RUY_DCHECK_NE(old_state, new_state); |
102 | RUY_DCHECK(old_state == State::Ready || old_state == State::HasWork); |
103 | RUY_DCHECK_NE(old_state, new_state); |
104 | #endif |
105 | |
106 | switch (new_state) { |
107 | case State::HasWork: |
108 | // See task_ member comment for the ordering of accesses. |
109 | RUY_DCHECK(!task_); |
110 | task_ = new_task; |
111 | break; |
112 | case State::ExitAsSoonAsPossible: |
113 | break; |
114 | default: |
115 | abort(); |
116 | } |
117 | // Release order because the worker thread will read this with acquire |
118 | // order. |
119 | state_.store(new_state, std::memory_order_release); |
120 | state_cond_mutex_.lock(); |
121 | state_cond_.notify_one(); // Only this one worker thread cares. |
122 | state_cond_mutex_.unlock(); |
123 | } |
124 | |
125 | static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); } |
126 | |
127 | // Waits for state_ to be different from State::Ready, and returns that |
128 | // new value. |
129 | State GetNewStateOtherThanReady() { |
130 | State new_state; |
131 | const auto& new_state_not_ready = [this, &new_state]() { |
132 | new_state = state_.load(std::memory_order_acquire); |
133 | return new_state != State::Ready; |
134 | }; |
135 | RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING); |
136 | Wait(new_state_not_ready, spin_duration_, &state_cond_, &state_cond_mutex_); |
137 | return new_state; |
138 | } |
139 | |
140 | // Thread entry point. |
141 | void ThreadFuncImpl() { |
142 | RUY_TRACE_SCOPE_NAME("Ruy worker thread function" ); |
143 | RevertToReadyState(); |
144 | |
145 | // Suppress denormals to avoid computation inefficiency. |
146 | ScopedSuppressDenormals suppress_denormals; |
147 | |
148 | // Thread loop |
149 | while (GetNewStateOtherThanReady() == State::HasWork) { |
150 | RevertToReadyState(); |
151 | } |
152 | |
153 | // Thread end. We should only get here if we were told to exit. |
154 | RUY_DCHECK(state_.load() == State::ExitAsSoonAsPossible); |
155 | } |
156 | |
157 | // The underlying thread. Used to join on destruction. |
158 | std::unique_ptr<std::thread> thread_; |
159 | |
160 | // The task to be worked on. |
161 | // |
162 | // The ordering of reads and writes to task_ is as follows. |
163 | // |
164 | // 1. The outside thread gives new work by calling |
165 | // ChangeStateFromOutsideThread(State::HasWork, new_task); |
166 | // That does: |
167 | // - a. Write task_ = new_task (non-atomic). |
168 | // - b. Store state_ = State::HasWork (memory_order_release). |
169 | // 2. The worker thread picks up the new state by calling |
170 | // GetNewStateOtherThanReady() |
171 | // That does: |
172 | // - c. Load state (memory_order_acquire). |
173 | // The worker thread then reads the new task in RevertToReadyState(). |
174 | // That does: |
175 | // - d. Read task_ (non-atomic). |
176 | // 3. The worker thread, still in RevertToReadyState(), consumes the task_ and |
177 | // does: |
178 | // - e. Write task_ = nullptr (non-atomic). |
179 | // And then calls Call count_busy_threads_->DecrementCount() |
180 | // which does |
181 | // - f. Store count_busy_threads_ (memory_order_release). |
182 | // 4. The outside thread, in ThreadPool::ExecuteImpl, finally waits for worker |
183 | // threads by calling count_busy_threads_->Wait(), which does: |
184 | // - g. Load count_busy_threads_ (memory_order_acquire). |
185 | // |
186 | // Thus the non-atomic write-then-read accesses to task_ (a. -> d.) are |
187 | // ordered by the release-acquire relationship of accesses to state_ (b. -> |
188 | // c.), and the non-atomic write accesses to task_ (e. -> a.) are ordered by |
189 | // the release-acquire relationship of accesses to count_busy_threads_ (f. -> |
190 | // g.). |
191 | Task* task_ = nullptr; |
192 | |
193 | // Condition variable used by the outside thread to notify the worker thread |
194 | // of a state change. |
195 | std::condition_variable state_cond_; |
196 | |
197 | // Mutex used to guard state_cond_ |
198 | std::mutex state_cond_mutex_; |
199 | |
200 | // The state enum tells if we're currently working, waiting for work, etc. |
201 | // It is written to from either the outside thread or the worker thread, |
202 | // in the ChangeState method. |
203 | // It is only ever read by the worker thread. |
204 | std::atomic<State> state_; |
205 | |
206 | // pointer to the master's thread BlockingCounter object, to notify the |
207 | // master thread of when this thread switches to the 'Ready' state. |
208 | BlockingCounter* const count_busy_threads_; |
209 | |
210 | // See ThreadPool::spin_duration_. |
211 | const Duration spin_duration_; |
212 | }; |
213 | |
214 | void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { |
215 | RUY_TRACE_SCOPE_NAME("ThreadPool::Execute" ); |
216 | RUY_DCHECK_GE(task_count, 1); |
217 | |
218 | // Case of 1 thread: just run the single task on the current thread. |
219 | if (task_count == 1) { |
220 | (tasks + 0)->Run(); |
221 | return; |
222 | } |
223 | |
224 | // Task #0 will be run on the current thread. |
225 | CreateThreads(task_count - 1); |
226 | count_busy_threads_.Reset(task_count - 1); |
227 | for (int i = 1; i < task_count; i++) { |
228 | RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK); |
229 | auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride; |
230 | threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address)); |
231 | } |
232 | |
233 | RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD); |
234 | // Execute task #0 immediately on the current thread. |
235 | (tasks + 0)->Run(); |
236 | |
237 | RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS); |
238 | // Wait for the threads submitted above to finish. |
239 | count_busy_threads_.Wait(spin_duration_); |
240 | } |
241 | |
242 | // Ensures that the pool has at least the given count of threads. |
243 | // If any new thread has to be created, this function waits for it to |
244 | // be ready. |
245 | void ThreadPool::CreateThreads(int threads_count) { |
246 | RUY_DCHECK_GE(threads_count, 0); |
247 | unsigned int unsigned_threads_count = threads_count; |
248 | if (threads_.size() >= unsigned_threads_count) { |
249 | return; |
250 | } |
251 | count_busy_threads_.Reset(threads_count - threads_.size()); |
252 | while (threads_.size() < unsigned_threads_count) { |
253 | threads_.push_back(new Thread(&count_busy_threads_, spin_duration_)); |
254 | } |
255 | count_busy_threads_.Wait(spin_duration_); |
256 | } |
257 | |
258 | ThreadPool::~ThreadPool() { |
259 | // Send all exit requests upfront so threads can work on them in parallel. |
260 | for (auto w : threads_) { |
261 | w->RequestExitAsSoonAsPossible(); |
262 | } |
263 | for (auto w : threads_) { |
264 | delete w; |
265 | } |
266 | } |
267 | |
268 | } // end namespace ruy |
269 | |