1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "ruy/thread_pool.h"
17
18#include <atomic>
19#include <chrono> // NOLINT(build/c++11)
20#include <condition_variable> // NOLINT(build/c++11)
21#include <cstdint>
22#include <cstdlib>
23#include <memory>
24#include <mutex> // NOLINT(build/c++11)
25#include <thread> // NOLINT(build/c++11)
26
27#include "ruy/check_macros.h"
28#include "ruy/denormal.h"
29#include "ruy/trace.h"
30#include "ruy/wait.h"
31
32namespace ruy {
33
34// A worker thread.
35class Thread {
36 public:
37 explicit Thread(BlockingCounter* count_busy_threads, Duration spin_duration)
38 : state_(State::Startup),
39 count_busy_threads_(count_busy_threads),
40 spin_duration_(spin_duration) {
41 thread_.reset(new std::thread(ThreadFunc, this));
42 }
43
44 void RequestExitAsSoonAsPossible() {
45 ChangeStateFromOutsideThread(State::ExitAsSoonAsPossible);
46 }
47
48 ~Thread() {
49 RUY_DCHECK_EQ(state_.load(), State::ExitAsSoonAsPossible);
50 thread_->join();
51 }
52
53 // Called by an outside thead to give work to the worker thread.
54 void StartWork(Task* task) {
55 ChangeStateFromOutsideThread(State::HasWork, task);
56 }
57
58 private:
59 enum class State {
60 Startup, // The initial state before the thread loop runs.
61 Ready, // Is not working, has not yet received new work to do.
62 HasWork, // Has work to do.
63 ExitAsSoonAsPossible // Should exit at earliest convenience.
64 };
65
66 // Implements the state_ change to State::Ready, which is where we consume
67 // task_. Only called on the worker thread.
68 // Reads task_, so assumes ordering past any prior writes to task_.
69 void RevertToReadyState() {
70 RUY_TRACE_SCOPE_NAME("Worker thread task");
71 // See task_ member comment for the ordering of accesses.
72 if (task_) {
73 task_->Run();
74 task_ = nullptr;
75 }
76 // No need to notify state_cond_, since only the worker thread ever waits
77 // on it, and we are that thread.
78 // Relaxed order because ordering is already provided by the
79 // count_busy_threads_->DecrementCount() at the next line, since the next
80 // state_ mutation will be to give new work and that won't happen before
81 // the outside thread has finished the current batch with a
82 // count_busy_threads_->Wait().
83 state_.store(State::Ready, std::memory_order_relaxed);
84 count_busy_threads_->DecrementCount();
85 }
86
87 // Changes State, from outside thread.
88 //
89 // The Task argument is to be used only with new_state==HasWork.
90 // It specifies the Task being handed to this Thread.
91 //
92 // new_task is only used with State::HasWork.
93 void ChangeStateFromOutsideThread(State new_state, Task* new_task = nullptr) {
94 RUY_DCHECK(new_state == State::ExitAsSoonAsPossible ||
95 new_state == State::HasWork);
96 RUY_DCHECK((new_task != nullptr) == (new_state == State::HasWork));
97
98#ifndef NDEBUG
99 // Debug-only sanity checks based on old_state.
100 State old_state = state_.load();
101 RUY_DCHECK_NE(old_state, new_state);
102 RUY_DCHECK(old_state == State::Ready || old_state == State::HasWork);
103 RUY_DCHECK_NE(old_state, new_state);
104#endif
105
106 switch (new_state) {
107 case State::HasWork:
108 // See task_ member comment for the ordering of accesses.
109 RUY_DCHECK(!task_);
110 task_ = new_task;
111 break;
112 case State::ExitAsSoonAsPossible:
113 break;
114 default:
115 abort();
116 }
117 // Release order because the worker thread will read this with acquire
118 // order.
119 state_.store(new_state, std::memory_order_release);
120 state_cond_mutex_.lock();
121 state_cond_.notify_one(); // Only this one worker thread cares.
122 state_cond_mutex_.unlock();
123 }
124
125 static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); }
126
127 // Waits for state_ to be different from State::Ready, and returns that
128 // new value.
129 State GetNewStateOtherThanReady() {
130 State new_state;
131 const auto& new_state_not_ready = [this, &new_state]() {
132 new_state = state_.load(std::memory_order_acquire);
133 return new_state != State::Ready;
134 };
135 RUY_TRACE_INFO(THREAD_FUNC_IMPL_WAITING);
136 Wait(new_state_not_ready, spin_duration_, &state_cond_, &state_cond_mutex_);
137 return new_state;
138 }
139
140 // Thread entry point.
141 void ThreadFuncImpl() {
142 RUY_TRACE_SCOPE_NAME("Ruy worker thread function");
143 RevertToReadyState();
144
145 // Suppress denormals to avoid computation inefficiency.
146 ScopedSuppressDenormals suppress_denormals;
147
148 // Thread loop
149 while (GetNewStateOtherThanReady() == State::HasWork) {
150 RevertToReadyState();
151 }
152
153 // Thread end. We should only get here if we were told to exit.
154 RUY_DCHECK(state_.load() == State::ExitAsSoonAsPossible);
155 }
156
157 // The underlying thread. Used to join on destruction.
158 std::unique_ptr<std::thread> thread_;
159
160 // The task to be worked on.
161 //
162 // The ordering of reads and writes to task_ is as follows.
163 //
164 // 1. The outside thread gives new work by calling
165 // ChangeStateFromOutsideThread(State::HasWork, new_task);
166 // That does:
167 // - a. Write task_ = new_task (non-atomic).
168 // - b. Store state_ = State::HasWork (memory_order_release).
169 // 2. The worker thread picks up the new state by calling
170 // GetNewStateOtherThanReady()
171 // That does:
172 // - c. Load state (memory_order_acquire).
173 // The worker thread then reads the new task in RevertToReadyState().
174 // That does:
175 // - d. Read task_ (non-atomic).
176 // 3. The worker thread, still in RevertToReadyState(), consumes the task_ and
177 // does:
178 // - e. Write task_ = nullptr (non-atomic).
179 // And then calls Call count_busy_threads_->DecrementCount()
180 // which does
181 // - f. Store count_busy_threads_ (memory_order_release).
182 // 4. The outside thread, in ThreadPool::ExecuteImpl, finally waits for worker
183 // threads by calling count_busy_threads_->Wait(), which does:
184 // - g. Load count_busy_threads_ (memory_order_acquire).
185 //
186 // Thus the non-atomic write-then-read accesses to task_ (a. -> d.) are
187 // ordered by the release-acquire relationship of accesses to state_ (b. ->
188 // c.), and the non-atomic write accesses to task_ (e. -> a.) are ordered by
189 // the release-acquire relationship of accesses to count_busy_threads_ (f. ->
190 // g.).
191 Task* task_ = nullptr;
192
193 // Condition variable used by the outside thread to notify the worker thread
194 // of a state change.
195 std::condition_variable state_cond_;
196
197 // Mutex used to guard state_cond_
198 std::mutex state_cond_mutex_;
199
200 // The state enum tells if we're currently working, waiting for work, etc.
201 // It is written to from either the outside thread or the worker thread,
202 // in the ChangeState method.
203 // It is only ever read by the worker thread.
204 std::atomic<State> state_;
205
206 // pointer to the master's thread BlockingCounter object, to notify the
207 // master thread of when this thread switches to the 'Ready' state.
208 BlockingCounter* const count_busy_threads_;
209
210 // See ThreadPool::spin_duration_.
211 const Duration spin_duration_;
212};
213
214void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) {
215 RUY_TRACE_SCOPE_NAME("ThreadPool::Execute");
216 RUY_DCHECK_GE(task_count, 1);
217
218 // Case of 1 thread: just run the single task on the current thread.
219 if (task_count == 1) {
220 (tasks + 0)->Run();
221 return;
222 }
223
224 // Task #0 will be run on the current thread.
225 CreateThreads(task_count - 1);
226 count_busy_threads_.Reset(task_count - 1);
227 for (int i = 1; i < task_count; i++) {
228 RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK);
229 auto task_address = reinterpret_cast<std::uintptr_t>(tasks) + i * stride;
230 threads_[i - 1]->StartWork(reinterpret_cast<Task*>(task_address));
231 }
232
233 RUY_TRACE_INFO(THREADPOOL_EXECUTE_STARTING_TASK_ZERO_ON_CUR_THREAD);
234 // Execute task #0 immediately on the current thread.
235 (tasks + 0)->Run();
236
237 RUY_TRACE_INFO(THREADPOOL_EXECUTE_WAITING_FOR_THREADS);
238 // Wait for the threads submitted above to finish.
239 count_busy_threads_.Wait(spin_duration_);
240}
241
242// Ensures that the pool has at least the given count of threads.
243// If any new thread has to be created, this function waits for it to
244// be ready.
245void ThreadPool::CreateThreads(int threads_count) {
246 RUY_DCHECK_GE(threads_count, 0);
247 unsigned int unsigned_threads_count = threads_count;
248 if (threads_.size() >= unsigned_threads_count) {
249 return;
250 }
251 count_busy_threads_.Reset(threads_count - threads_.size());
252 while (threads_.size() < unsigned_threads_count) {
253 threads_.push_back(new Thread(&count_busy_threads_, spin_duration_));
254 }
255 count_busy_threads_.Wait(spin_duration_);
256}
257
258ThreadPool::~ThreadPool() {
259 // Send all exit requests upfront so threads can work on them in parallel.
260 for (auto w : threads_) {
261 w->RequestExitAsSoonAsPossible();
262 }
263 for (auto w : threads_) {
264 delete w;
265 }
266}
267
268} // end namespace ruy
269