1 | #include <torch/csrc/lazy/core/thread_pool.h> |
---|---|
2 | |
3 | #include <c10/util/Logging.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/lazy/core/config.h> |
6 | #include <torch/csrc/lazy/core/metrics.h> |
7 | |
8 | #include <condition_variable> |
9 | #include <deque> |
10 | #include <exception> |
11 | #include <mutex> |
12 | |
13 | namespace torch { |
14 | namespace lazy { |
15 | namespace { |
16 | |
// Fixed-size pool of worker threads used to run lazy-tensor IO closures.
// Workers block on a condition variable and pull work from a FIFO queue.
// Schedule() falls back to spawning a one-off detached thread when no
// worker is idle (see the comment inside Schedule()).
class ThreadPool {
 public:
  // Spawns `num_threads` workers, each looping in Worker() until the pool
  // is destroyed.
  explicit ThreadPool(size_t num_threads) {
    threads_.reserve(num_threads);
    for (const auto i : c10::irange(num_threads)) {
      (void)i; // Suppress unused variable warning
      threads_.emplace_back([this]() { Worker(); });
    }
  }

  // Signals shutdown and joins every worker. Closures still queued at this
  // point are drained first: GetWork() keeps handing out work while the
  // queue is non-empty and returns nullptr only once it is empty.
  ~ThreadPool() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      exiting_ = true;
      // Notified while still holding the lock so no worker can re-enter
      // the wait between the flag flip and the notification.
      cv_.notify_all();
    }
    for (auto& thread : threads_) {
      thread.join();
    }
  }

  // Runs `closure` asynchronously: enqueues it for a pool worker when one
  // is idle, otherwise executes it on a freshly spawned detached thread.
  void Schedule(std::function<void()> closure) {
    // If we have more work scheduled than waiting worker threads, just schedule
    // it on a separate thread. This prevents tricky thread-pool-size-deadlocks
    // caused by an undersized thread pool and closures that end up doing sync
    // waits on the pool threads.
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (work_.size() < waiting_) {
        work_.emplace_back(std::move(closure));
        // Unlock before notifying so the woken worker does not immediately
        // block on the mutex we still hold.
        lock.unlock();
        cv_.notify_one();
        return;
      }
    }
    ScheduleOnThread(std::move(closure));
  }

 private:
  // Worker loop: pull and run closures until GetWork() returns nullptr
  // (which signals shutdown with an empty queue).
  void Worker() {
    while (true) {
      std::function<void()> closure = GetWork();
      if (closure == nullptr) {
        break;
      }
      try {
        closure();
      } catch (const std::exception& ex) {
        // Exceptions are counted and logged, never propagated, so one bad
        // closure cannot take down the worker thread.
        TORCH_LAZY_COUNTER("ThreadPoolException", 1);
        LOG(ERROR) << "Exception from running thread pool closure: "
                   << ex.what();
      }
    }
  }

  // Overflow path for Schedule(): run the closure on its own detached
  // thread, outside the pool.
  void ScheduleOnThread(std::function<void()> closure) {
    std::thread thread(std::move(closure));
    thread.detach();
  }

  // Blocks until work is available or shutdown is requested. Returns the
  // next closure, or nullptr when exiting with an empty queue.
  std::function<void()> GetWork() {
    std::unique_lock<std::mutex> lock(mutex_);
    // waiting_ is incremented around the wait so Schedule() can tell how
    // many workers are idle and able to pick up newly queued work.
    ++waiting_;
    cv_.wait(lock, [this] { return exiting_ || !work_.empty(); });
    --waiting_;
    if (work_.empty()) {
      // Only reachable when exiting_ is set: drain-then-exit semantics.
      return nullptr;
    }
    std::function<void()> closure(std::move(work_.front()));
    work_.pop_front();
    return closure;
  }

  std::vector<std::thread> threads_;
  std::mutex mutex_; // Guards exiting_, work_ and waiting_.
  std::condition_variable cv_;
  bool exiting_ = false;
  std::deque<std::function<void()>> work_; // FIFO queue of pending closures.
  size_t waiting_ = 0; // Number of workers currently blocked in cv_.wait().
};
97 | |
98 | ThreadPool* GetIoThreadPool() { |
99 | static ThreadPool* pool = |
100 | new ThreadPool(FLAGS_torch_lazy_io_thread_pool_size); |
101 | return pool; |
102 | } |
103 | |
104 | } // namespace |
105 | |
106 | class Completion::Data { |
107 | public: |
108 | void Wait() { |
109 | std::unique_lock<std::mutex> lock(mutex_); |
110 | cv_.wait(lock, [this] { return completed_; }); |
111 | if (exptr_ != nullptr) { |
112 | std::rethrow_exception(exptr_); |
113 | } |
114 | } |
115 | |
116 | static std::function<void()> GetCompleter( |
117 | std::shared_ptr<Data> data, |
118 | std::function<void()> closure) { |
119 | auto closure_wrapper = [closure = std::move(closure), data]() { |
120 | std::exception_ptr exptr; |
121 | try { |
122 | closure(); |
123 | } catch (...) { |
124 | exptr = std::current_exception(); |
125 | } |
126 | data->Complete(exptr); |
127 | }; |
128 | return closure_wrapper; |
129 | } |
130 | |
131 | private: |
132 | void Complete(std::exception_ptr exptr) { |
133 | std::lock_guard<std::mutex> lock(mutex_); |
134 | exptr_ = std::move(exptr); |
135 | completed_ = true; |
136 | cv_.notify_all(); |
137 | } |
138 | |
139 | std::mutex mutex_; |
140 | std::condition_variable cv_; |
141 | bool completed_ = false; |
142 | std::exception_ptr exptr_; |
143 | }; |
144 | |
// Takes ownership of the shared completion state produced alongside
// Completion::Data::GetCompleter().
Completion::Completion(std::shared_ptr<Data> data) : data_(std::move(data)) {}
146 | |
// Blocks until the associated closure has finished; rethrows on this
// thread any exception the closure raised.
void Completion::Wait() {
  data_->Wait();
}
150 | |
// Fire-and-forget: runs `closure` on the global IO thread pool with no way
// to observe its completion or exceptions (see
// ScheduleIoClosureWithCompletion for the observable variant).
void ScheduleIoClosure(std::function<void()> closure) {
  GetIoThreadPool()->Schedule(std::move(closure));
}
154 | |
155 | Completion ScheduleIoClosureWithCompletion(std::function<void()> closure) { |
156 | auto data = std::make_shared<Completion::Data>(); |
157 | GetIoThreadPool()->Schedule( |
158 | Completion::Data::GetCompleter(data, std::move(closure))); |
159 | return Completion(std::move(data)); |
160 | } |
161 | |
162 | } // namespace lazy |
163 | } // namespace torch |
164 |