1 | #include <c10/core/thread_pool.h> |
2 | |
3 | namespace c10 { |
4 | |
5 | ThreadPool::ThreadPool( |
6 | int pool_size, |
7 | int numa_node_id, |
8 | std::function<void()> init_thread) |
9 | : threads_(pool_size < 0 ? defaultNumThreads() : pool_size), |
10 | running_(true), |
11 | complete_(true), |
12 | available_(threads_.size()), |
13 | total_(threads_.size()), |
14 | numa_node_id_(numa_node_id) { |
15 | for (std::size_t i = 0; i < threads_.size(); ++i) { |
16 | threads_[i] = std::thread([this, i, init_thread]() { |
17 | if (init_thread) { |
18 | init_thread(); |
19 | } |
20 | this->main_loop(i); |
21 | }); |
22 | } |
23 | } |
24 | |
25 | ThreadPool::~ThreadPool() { |
26 | // Set running flag to false then notify all threads. |
27 | { |
28 | std::unique_lock<std::mutex> lock(mutex_); |
29 | running_ = false; |
30 | condition_.notify_all(); |
31 | } |
32 | |
33 | for (auto& t : threads_) { |
34 | try { |
35 | t.join(); |
36 | } catch (const std::exception&) { |
37 | } |
38 | } |
39 | } |
40 | |
41 | size_t ThreadPool::size() const { |
42 | return threads_.size(); |
43 | } |
44 | |
45 | size_t ThreadPool::numAvailable() const { |
46 | std::unique_lock<std::mutex> lock(mutex_); |
47 | return available_; |
48 | } |
49 | |
50 | bool ThreadPool::inThreadPool() const { |
51 | for (auto& thread : threads_) { |
52 | if (thread.get_id() == std::this_thread::get_id()) { |
53 | return true; |
54 | } |
55 | } |
56 | return false; |
57 | } |
58 | |
59 | void ThreadPool::run(std::function<void()> func) { |
60 | if (threads_.empty()) { |
61 | throw std::runtime_error("No threads to run a task" ); |
62 | } |
63 | std::unique_lock<std::mutex> lock(mutex_); |
64 | |
65 | // Set task and signal condition variable so that a worker thread will |
66 | // wake up and use the task. |
67 | tasks_.emplace(std::move(func)); |
68 | complete_ = false; |
69 | condition_.notify_one(); |
70 | } |
71 | |
72 | void ThreadPool::waitWorkComplete() { |
73 | std::unique_lock<std::mutex> lock(mutex_); |
74 | completed_.wait(lock, [&]() { return complete_; }); |
75 | } |
76 | |
77 | void ThreadPool::main_loop(std::size_t index) { |
78 | std::unique_lock<std::mutex> lock(mutex_); |
79 | while (running_) { |
80 | // Wait on condition variable while the task is empty and |
81 | // the pool is still running. |
82 | condition_.wait(lock, [&]() { return !tasks_.empty() || !running_; }); |
83 | // If pool is no longer running, break out of loop. |
84 | if (!running_) { |
85 | break; |
86 | } |
87 | |
88 | // Copy task locally and remove from the queue. This is |
89 | // done within its own scope so that the task object is |
90 | // destructed immediately after running the task. This is |
91 | // useful in the event that the function contains |
92 | // shared_ptr arguments bound via bind. |
93 | { |
94 | task_element_t tasks = std::move(tasks_.front()); |
95 | tasks_.pop(); |
96 | // Decrement count, indicating thread is no longer available. |
97 | --available_; |
98 | |
99 | lock.unlock(); |
100 | |
101 | // Run the task. |
102 | try { |
103 | if (tasks.run_with_id) { |
104 | tasks.with_id(index); |
105 | } else { |
106 | tasks.no_id(); |
107 | } |
108 | } catch (const std::exception& e) { |
109 | LOG(ERROR) << "Exception in thread pool task: " << e.what(); |
110 | } catch (...) { |
111 | LOG(ERROR) << "Exception in thread pool task: unknown" ; |
112 | } |
113 | |
114 | // Destruct tasks before taking the lock. As tasks |
115 | // are user provided std::function, they can run |
116 | // arbitrary code during destruction, including code |
117 | // that can reentrantly call into ThreadPool (which would |
118 | // cause a deadlock if we were holding the lock). |
119 | } |
120 | |
121 | // Update status of empty, maybe |
122 | // Need to recover the lock first |
123 | lock.lock(); |
124 | |
125 | // Increment count, indicating thread is available. |
126 | ++available_; |
127 | if (tasks_.empty() && available_ == total_) { |
128 | complete_ = true; |
129 | completed_.notify_one(); |
130 | } |
131 | |
132 | // Deliberately hold the lock on the backedge, so this thread has an |
133 | // opportunity to acquire a new task before another thread acquires |
134 | // the lock. |
135 | } // while running_ |
136 | } |
137 | |
// Defines the ThreadPoolRegistry registry through C10_DEFINE_SHARED_REGISTRY.
// NOTE(review): presumably the registry creates shared TaskThreadPoolBase
// objects from (int, int, bool) creator arguments -- confirm against the
// macro definition and the matching registry declaration.
C10_DEFINE_SHARED_REGISTRY(
    ThreadPoolRegistry,
    TaskThreadPoolBase,
    int,
    int,
    bool);
144 | } // namespace c10 |
145 | |