1/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0
17// license.
18
19#ifndef RUY_RUY_THREAD_POOL_H_
20#define RUY_RUY_THREAD_POOL_H_
21
#include <type_traits>
#include <vector>

#include "ruy/blocking_counter.h"
#include "ruy/time.h"
26
27namespace ruy {
28
// A workload for a thread.
//
// Subclasses implement Run() with the actual work to perform. Tasks are
// passed to ThreadPool::Execute through base-class pointers, one task per
// thread (see the ThreadPool class comment below).
struct Task {
  // Virtual destructor: tasks are manipulated via Task* pointers, so
  // destruction must dispatch to the subclass. `= default` keeps the
  // destructor trivial when the subclass allows it.
  virtual ~Task() = default;
  // The work to perform, run on the thread this task is assigned to.
  virtual void Run() = 0;
};
34
35class Thread;
36
37// A simple pool of threads, that only allows the very
38// specific parallelization pattern that we use here:
39// One thread, which we call the 'main thread', calls Execute, distributing
40// a Task each to N threads, being N-1 'worker threads' and the main thread
41// itself. After the main thread has completed its own Task, it waits for
42// the worker threads to have all completed. That is the only synchronization
43// performed by this ThreadPool.
44//
45// In particular, there is a naive 1:1 mapping of Tasks to threads.
46// This ThreadPool considers it outside of its own scope to try to work
47// with fewer threads than there are Tasks. The idea is that such N:M mappings
48// of tasks to threads can be implemented as a higher-level feature on top of
49// the present low-level 1:1 threadpool. For example, a user might have a
50// Task subclass referencing a shared atomic counter indexing into a vector of
51// finer-granularity subtasks. Different threads would then concurrently
52// increment this atomic counter, getting each their own subtasks to work on.
53// That approach is the one used in ruy's multi-thread matrix multiplication
54// implementation --- see ruy's TrMulTask.
55class ThreadPool {
56 public:
57 ThreadPool() {}
58
59 ~ThreadPool();
60
61 // Executes task_count tasks on task_count threads.
62 // Grows the threadpool as needed to have at least (task_count-1) threads.
63 // The 0-th task is run on the thread on which Execute is called: that
64 // is by definition what we call the "main thread". Synchronization of all
65 // threads is performed before this function returns.
66 //
67 // As explained in the class comment, there is a 1:1 mapping of tasks to
68 // threads. If you need something smarter than that, for instance if you
69 // want to run an unbounded number of tasks on a bounded number of threads,
70 // then you need something higher-level than this ThreadPool, that can
71 // be layered on top of it by appropriately subclassing Tasks.
72 //
73 // TaskType must be a subclass of ruy::Task. That is implicitly guarded by
74 // the static_cast in this inline implementation.
75 template <typename TaskType>
76 void Execute(int task_count, TaskType* tasks) {
77 ExecuteImpl(task_count, sizeof(TaskType), static_cast<Task*>(tasks));
78 }
79
80 void set_spin_milliseconds(float milliseconds) {
81 spin_duration_ = DurationFromMilliseconds(milliseconds);
82 }
83
84 float spin_milliseconds() const {
85 return ToFloatMilliseconds(spin_duration_);
86 }
87
88 private:
89 // Ensures that the pool has at least the given count of threads.
90 // If any new thread has to be created, this function waits for it to
91 // be ready.
92 void CreateThreads(int threads_count);
93
94 // Non-templatized implementation of the public Execute method.
95 // See the inline implementation of Execute for how this is used.
96 void ExecuteImpl(int task_count, int stride, Task* tasks);
97
98 // copy construction disallowed
99 ThreadPool(const ThreadPool&) = delete;
100
101 // The worker threads in this pool. They are owned by the pool:
102 // the pool creates threads and destroys them in its destructor.
103 std::vector<Thread*> threads_;
104
105 // The BlockingCounter used to wait for the threads.
106 BlockingCounter count_busy_threads_;
107
108 // This value was empirically derived with some microbenchmark, we don't have
109 // high confidence in it.
110 //
111 // That this value means that we may be sleeping substantially longer
112 // than a scheduler timeslice's duration is not necessarily surprising. The
113 // idea is to pick up quickly new work after having finished the previous
114 // workload. When it's new work within the same GEMM as the previous work, the
115 // time interval that we might be busy-waiting is very small, so for that
116 // purpose it would be more than enough to sleep for 1 ms.
117 // That is all what we would observe on a GEMM benchmark. However, in a real
118 // application, after having finished a GEMM, we might do unrelated work for
119 // a little while, then start on a new GEMM. In that case the wait interval
120 // may be a little longer. There may also not be another GEMM for a long time,
121 // in which case we'll end up passively waiting below.
122 Duration spin_duration_ = DurationFromMilliseconds(2);
123};
124
125} // namespace ruy
126
127#endif // RUY_RUY_THREAD_POOL_H_
128