1 | /* Copyright 2019 Google LLC. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 |
17 | // license. |
18 | |
19 | #ifndef RUY_RUY_THREAD_POOL_H_ |
20 | #define RUY_RUY_THREAD_POOL_H_ |
21 | |
#include <type_traits>
#include <vector>

#include "ruy/blocking_counter.h"
#include "ruy/time.h"
26 | |
27 | namespace ruy { |
28 | |
// Abstract unit of work handed to a single thread.
// Concrete workloads derive from Task and implement Run().
struct Task {
  // Virtual, so that a derived task can be destroyed through a Task pointer.
  virtual ~Task() = default;

  // Performs the work of this task. Must be overridden by subclasses.
  virtual void Run() = 0;
};
34 | |
35 | class Thread; |
36 | |
37 | // A simple pool of threads, that only allows the very |
38 | // specific parallelization pattern that we use here: |
39 | // One thread, which we call the 'main thread', calls Execute, distributing |
40 | // a Task each to N threads, being N-1 'worker threads' and the main thread |
41 | // itself. After the main thread has completed its own Task, it waits for |
42 | // the worker threads to have all completed. That is the only synchronization |
43 | // performed by this ThreadPool. |
44 | // |
45 | // In particular, there is a naive 1:1 mapping of Tasks to threads. |
46 | // This ThreadPool considers it outside of its own scope to try to work |
47 | // with fewer threads than there are Tasks. The idea is that such N:M mappings |
48 | // of tasks to threads can be implemented as a higher-level feature on top of |
49 | // the present low-level 1:1 threadpool. For example, a user might have a |
50 | // Task subclass referencing a shared atomic counter indexing into a vector of |
51 | // finer-granularity subtasks. Different threads would then concurrently |
52 | // increment this atomic counter, getting each their own subtasks to work on. |
53 | // That approach is the one used in ruy's multi-thread matrix multiplication |
54 | // implementation --- see ruy's TrMulTask. |
55 | class ThreadPool { |
56 | public: |
57 | ThreadPool() {} |
58 | |
59 | ~ThreadPool(); |
60 | |
61 | // Executes task_count tasks on task_count threads. |
62 | // Grows the threadpool as needed to have at least (task_count-1) threads. |
63 | // The 0-th task is run on the thread on which Execute is called: that |
64 | // is by definition what we call the "main thread". Synchronization of all |
65 | // threads is performed before this function returns. |
66 | // |
67 | // As explained in the class comment, there is a 1:1 mapping of tasks to |
68 | // threads. If you need something smarter than that, for instance if you |
69 | // want to run an unbounded number of tasks on a bounded number of threads, |
70 | // then you need something higher-level than this ThreadPool, that can |
71 | // be layered on top of it by appropriately subclassing Tasks. |
72 | // |
73 | // TaskType must be a subclass of ruy::Task. That is implicitly guarded by |
74 | // the static_cast in this inline implementation. |
75 | template <typename TaskType> |
76 | void Execute(int task_count, TaskType* tasks) { |
77 | ExecuteImpl(task_count, sizeof(TaskType), static_cast<Task*>(tasks)); |
78 | } |
79 | |
80 | void set_spin_milliseconds(float milliseconds) { |
81 | spin_duration_ = DurationFromMilliseconds(milliseconds); |
82 | } |
83 | |
84 | float spin_milliseconds() const { |
85 | return ToFloatMilliseconds(spin_duration_); |
86 | } |
87 | |
88 | private: |
89 | // Ensures that the pool has at least the given count of threads. |
90 | // If any new thread has to be created, this function waits for it to |
91 | // be ready. |
92 | void CreateThreads(int threads_count); |
93 | |
94 | // Non-templatized implementation of the public Execute method. |
95 | // See the inline implementation of Execute for how this is used. |
96 | void ExecuteImpl(int task_count, int stride, Task* tasks); |
97 | |
98 | // copy construction disallowed |
99 | ThreadPool(const ThreadPool&) = delete; |
100 | |
101 | // The worker threads in this pool. They are owned by the pool: |
102 | // the pool creates threads and destroys them in its destructor. |
103 | std::vector<Thread*> threads_; |
104 | |
105 | // The BlockingCounter used to wait for the threads. |
106 | BlockingCounter count_busy_threads_; |
107 | |
108 | // This value was empirically derived with some microbenchmark, we don't have |
109 | // high confidence in it. |
110 | // |
111 | // That this value means that we may be sleeping substantially longer |
112 | // than a scheduler timeslice's duration is not necessarily surprising. The |
113 | // idea is to pick up quickly new work after having finished the previous |
114 | // workload. When it's new work within the same GEMM as the previous work, the |
115 | // time interval that we might be busy-waiting is very small, so for that |
116 | // purpose it would be more than enough to sleep for 1 ms. |
117 | // That is all what we would observe on a GEMM benchmark. However, in a real |
118 | // application, after having finished a GEMM, we might do unrelated work for |
119 | // a little while, then start on a new GEMM. In that case the wait interval |
120 | // may be a little longer. There may also not be another GEMM for a long time, |
121 | // in which case we'll end up passively waiting below. |
122 | Duration spin_duration_ = DurationFromMilliseconds(2); |
123 | }; |
124 | |
125 | } // namespace ruy |
126 | |
127 | #endif // RUY_RUY_THREAD_POOL_H_ |
128 | |