1 | #pragma once |
2 | |
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

#include <c10/util/numa.h>
#include <c10/util/thread_name.h>
13 | |
14 | C10_CLANG_DIAGNOSTIC_PUSH() |
15 | #if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") |
16 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32" ) |
17 | #endif |
18 | |
19 | namespace c10 { |
20 | |
21 | // TODO: move this to C10 and make it C10_API |
22 | class C10_API TaskThreadPoolBase { |
23 | public: |
24 | virtual void run(std::function<void()> func) = 0; |
25 | |
26 | virtual size_t size() const = 0; |
27 | |
28 | /** |
29 | * The number of available (i.e. idle) threads in this thread pool. |
30 | */ |
31 | virtual size_t numAvailable() const = 0; |
32 | |
33 | /** |
34 | * Check if the current thread is from the thread pool. |
35 | */ |
36 | virtual bool inThreadPool() const = 0; |
37 | |
38 | virtual ~TaskThreadPoolBase() noexcept = default; |
39 | |
40 | static size_t defaultNumThreads() { |
41 | auto num_threads = std::thread::hardware_concurrency(); |
42 | #if defined(_M_X64) || defined(__x86_64__) |
43 | num_threads /= 2; |
44 | #endif |
45 | return num_threads; |
46 | } |
47 | }; |
48 | |
49 | class C10_API ThreadPool : public c10::TaskThreadPoolBase { |
50 | protected: |
51 | struct task_element_t { |
52 | bool run_with_id; |
53 | const std::function<void()> no_id; |
54 | const std::function<void(std::size_t)> with_id; |
55 | |
56 | explicit task_element_t(std::function<void()> f) |
57 | : run_with_id(false), no_id(std::move(f)), with_id(nullptr) {} |
58 | explicit task_element_t(std::function<void(std::size_t)> f) |
59 | : run_with_id(true), no_id(nullptr), with_id(std::move(f)) {} |
60 | }; |
61 | |
62 | std::queue<task_element_t> tasks_; |
63 | std::vector<std::thread> threads_; |
64 | mutable std::mutex mutex_; |
65 | std::condition_variable condition_; |
66 | std::condition_variable completed_; |
67 | std::atomic_bool running_; |
68 | bool complete_; |
69 | std::size_t available_; |
70 | std::size_t total_; |
71 | int numa_node_id_; |
72 | |
73 | public: |
74 | ThreadPool() = delete; |
75 | |
76 | explicit ThreadPool( |
77 | int pool_size, |
78 | int numa_node_id = -1, |
79 | std::function<void()> init_thread = nullptr); |
80 | |
81 | ~ThreadPool() override; |
82 | |
83 | size_t size() const override; |
84 | |
85 | size_t numAvailable() const override; |
86 | |
87 | bool inThreadPool() const override; |
88 | |
89 | void run(std::function<void()> func) override; |
90 | |
91 | template <typename Task> |
92 | void runTaskWithID(Task task) { |
93 | std::unique_lock<std::mutex> lock(mutex_); |
94 | |
95 | // Set task and signal condition variable so that a worker thread will |
96 | // wake up and use the task. |
97 | tasks_.emplace(static_cast<std::function<void(std::size_t)>>(task)); |
98 | complete_ = false; |
99 | condition_.notify_one(); |
100 | } |
101 | |
102 | /// @brief Wait for queue to be empty |
103 | void waitWorkComplete(); |
104 | |
105 | private: |
106 | // @brief Entry point for pool threads. |
107 | void main_loop(std::size_t index); |
108 | }; |
109 | |
110 | class C10_API TaskThreadPool : public c10::ThreadPool { |
111 | public: |
112 | explicit TaskThreadPool(std::size_t pool_size, int numa_node_id = -1) |
113 | : ThreadPool(pool_size, numa_node_id, [numa_node_id]() { |
114 | setThreadName("CaffeTaskThread" ); |
115 | NUMABind(numa_node_id); |
116 | }) {} |
117 | }; |
118 | |
// Declares the shared registry `ThreadPoolRegistry` for TaskThreadPoolBase.
// NOTE(review): the trailing `int, int, bool` are presumably the argument
// types passed when a registered pool is created - confirm against the
// definition of C10_DECLARE_SHARED_REGISTRY.
C10_DECLARE_SHARED_REGISTRY(
ThreadPoolRegistry,
TaskThreadPoolBase,
int,
int,
bool);
125 | |
126 | } // namespace c10 |
127 | |
128 | C10_CLANG_DIAGNOSTIC_POP() |
129 | |