1#include <c10/core/thread_pool.h>
2
3namespace c10 {
4
5ThreadPool::ThreadPool(
6 int pool_size,
7 int numa_node_id,
8 std::function<void()> init_thread)
9 : threads_(pool_size < 0 ? defaultNumThreads() : pool_size),
10 running_(true),
11 complete_(true),
12 available_(threads_.size()),
13 total_(threads_.size()),
14 numa_node_id_(numa_node_id) {
15 for (std::size_t i = 0; i < threads_.size(); ++i) {
16 threads_[i] = std::thread([this, i, init_thread]() {
17 if (init_thread) {
18 init_thread();
19 }
20 this->main_loop(i);
21 });
22 }
23}
24
25ThreadPool::~ThreadPool() {
26 // Set running flag to false then notify all threads.
27 {
28 std::unique_lock<std::mutex> lock(mutex_);
29 running_ = false;
30 condition_.notify_all();
31 }
32
33 for (auto& t : threads_) {
34 try {
35 t.join();
36 } catch (const std::exception&) {
37 }
38 }
39}
40
41size_t ThreadPool::size() const {
42 return threads_.size();
43}
44
45size_t ThreadPool::numAvailable() const {
46 std::unique_lock<std::mutex> lock(mutex_);
47 return available_;
48}
49
50bool ThreadPool::inThreadPool() const {
51 for (auto& thread : threads_) {
52 if (thread.get_id() == std::this_thread::get_id()) {
53 return true;
54 }
55 }
56 return false;
57}
58
59void ThreadPool::run(std::function<void()> func) {
60 if (threads_.empty()) {
61 throw std::runtime_error("No threads to run a task");
62 }
63 std::unique_lock<std::mutex> lock(mutex_);
64
65 // Set task and signal condition variable so that a worker thread will
66 // wake up and use the task.
67 tasks_.emplace(std::move(func));
68 complete_ = false;
69 condition_.notify_one();
70}
71
72void ThreadPool::waitWorkComplete() {
73 std::unique_lock<std::mutex> lock(mutex_);
74 completed_.wait(lock, [&]() { return complete_; });
75}
76
77void ThreadPool::main_loop(std::size_t index) {
78 std::unique_lock<std::mutex> lock(mutex_);
79 while (running_) {
80 // Wait on condition variable while the task is empty and
81 // the pool is still running.
82 condition_.wait(lock, [&]() { return !tasks_.empty() || !running_; });
83 // If pool is no longer running, break out of loop.
84 if (!running_) {
85 break;
86 }
87
88 // Copy task locally and remove from the queue. This is
89 // done within its own scope so that the task object is
90 // destructed immediately after running the task. This is
91 // useful in the event that the function contains
92 // shared_ptr arguments bound via bind.
93 {
94 task_element_t tasks = std::move(tasks_.front());
95 tasks_.pop();
96 // Decrement count, indicating thread is no longer available.
97 --available_;
98
99 lock.unlock();
100
101 // Run the task.
102 try {
103 if (tasks.run_with_id) {
104 tasks.with_id(index);
105 } else {
106 tasks.no_id();
107 }
108 } catch (const std::exception& e) {
109 LOG(ERROR) << "Exception in thread pool task: " << e.what();
110 } catch (...) {
111 LOG(ERROR) << "Exception in thread pool task: unknown";
112 }
113
114 // Destruct tasks before taking the lock. As tasks
115 // are user provided std::function, they can run
116 // arbitrary code during destruction, including code
117 // that can reentrantly call into ThreadPool (which would
118 // cause a deadlock if we were holding the lock).
119 }
120
121 // Update status of empty, maybe
122 // Need to recover the lock first
123 lock.lock();
124
125 // Increment count, indicating thread is available.
126 ++available_;
127 if (tasks_.empty() && available_ == total_) {
128 complete_ = true;
129 completed_.notify_one();
130 }
131
132 // Deliberately hold the lock on the backedge, so this thread has an
133 // opportunity to acquire a new task before another thread acquires
134 // the lock.
135 } // while running_
136}
137
// Defines ThreadPoolRegistry with the C10_DEFINE_SHARED_REGISTRY macro, for
// the TaskThreadPoolBase type, with creator argument types (int, int, bool).
// NOTE(review): this is presumably a factory registry producing
// TaskThreadPoolBase instances. The argument meaning is presumably
// (device_id, pool_size, create_new); confirm against the matching
// C10_DECLARE_SHARED_REGISTRY in the header and the registered creators.
C10_DEFINE_SHARED_REGISTRY(
    ThreadPoolRegistry,
    TaskThreadPoolBase,
    int,
    int,
    bool);
144} // namespace c10
145