1#pragma once
2
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include <c10/util/numa.h>
#include <c10/util/thread_name.h>
13
14C10_CLANG_DIAGNOSTIC_PUSH()
15#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
16C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
17#endif
18
19namespace c10 {
20
21// TODO: move this to C10 and make it C10_API
22class C10_API TaskThreadPoolBase {
23 public:
24 virtual void run(std::function<void()> func) = 0;
25
26 virtual size_t size() const = 0;
27
28 /**
29 * The number of available (i.e. idle) threads in this thread pool.
30 */
31 virtual size_t numAvailable() const = 0;
32
33 /**
34 * Check if the current thread is from the thread pool.
35 */
36 virtual bool inThreadPool() const = 0;
37
38 virtual ~TaskThreadPoolBase() noexcept = default;
39
40 static size_t defaultNumThreads() {
41 auto num_threads = std::thread::hardware_concurrency();
42#if defined(_M_X64) || defined(__x86_64__)
43 num_threads /= 2;
44#endif
45 return num_threads;
46 }
47};
48
49class C10_API ThreadPool : public c10::TaskThreadPoolBase {
50 protected:
51 struct task_element_t {
52 bool run_with_id;
53 const std::function<void()> no_id;
54 const std::function<void(std::size_t)> with_id;
55
56 explicit task_element_t(std::function<void()> f)
57 : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {}
58 explicit task_element_t(std::function<void(std::size_t)> f)
59 : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {}
60 };
61
62 std::queue<task_element_t> tasks_;
63 std::vector<std::thread> threads_;
64 mutable std::mutex mutex_;
65 std::condition_variable condition_;
66 std::condition_variable completed_;
67 std::atomic_bool running_;
68 bool complete_;
69 std::size_t available_;
70 std::size_t total_;
71 int numa_node_id_;
72
73 public:
74 ThreadPool() = delete;
75
76 explicit ThreadPool(
77 int pool_size,
78 int numa_node_id = -1,
79 std::function<void()> init_thread = nullptr);
80
81 ~ThreadPool() override;
82
83 size_t size() const override;
84
85 size_t numAvailable() const override;
86
87 bool inThreadPool() const override;
88
89 void run(std::function<void()> func) override;
90
91 template <typename Task>
92 void runTaskWithID(Task task) {
93 std::unique_lock<std::mutex> lock(mutex_);
94
95 // Set task and signal condition variable so that a worker thread will
96 // wake up and use the task.
97 tasks_.emplace(static_cast<std::function<void(std::size_t)>>(task));
98 complete_ = false;
99 condition_.notify_one();
100 }
101
102 /// @brief Wait for queue to be empty
103 void waitWorkComplete();
104
105 private:
106 // @brief Entry point for pool threads.
107 void main_loop(std::size_t index);
108};
109
110class C10_API TaskThreadPool : public c10::ThreadPool {
111 public:
112 explicit TaskThreadPool(std::size_t pool_size, int numa_node_id = -1)
113 : ThreadPool(pool_size, numa_node_id, [numa_node_id]() {
114 setThreadName("CaffeTaskThread");
115 NUMABind(numa_node_id);
116 }) {}
117};
118
119C10_DECLARE_SHARED_REGISTRY(
120 ThreadPoolRegistry,
121 TaskThreadPoolBase,
122 int,
123 int,
124 bool);
125
126} // namespace c10
127
128C10_CLANG_DIAGNOSTIC_POP()
129