#include <torch/csrc/lazy/core/thread_pool.h>

#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/metrics.h>

#include <condition_variable>
#include <deque>
#include <exception>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

namespace torch {
namespace lazy {
namespace {

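// Minimal worker pool used to run lazy-tensor I/O closures. Closures run on
// a fixed set of worker threads when one is idle, and on a fresh detached
// thread otherwise (see Schedule() below).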
class ThreadPool {
 public:
  explicit ThreadPool(size_t num_threads) {
    threads_.reserve(num_threads);
    for (const auto i : c10::irange(num_threads)) {
      (void)i; // Suppress unused variable warning
      threads_.emplace_back([this]() { Worker(); });
    }
  }

  ~ThreadPool() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      exiting_ = true;
      cv_.notify_all();
    }
    for (auto& thread : threads_) {
      thread.join();
    }
  }

  void Schedule(std::function<void()> closure) {
    // If there are more closures scheduled than waiting worker threads, run
    // the new closure on a dedicated thread instead. This avoids deadlocks
    // that can occur when an undersized pool runs closures that synchronously
    // wait on other closures scheduled on the same pool.
    {
      std::unique_lock<std::mutex> lock(mutex_);
      if (work_.size() < waiting_) {
        work_.emplace_back(std::move(closure));
        lock.unlock();
        cv_.notify_one();
        return;
      }
    }
    ScheduleOnThread(std::move(closure));
  }

 private:
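  // Worker loop: repeatedly pulls closures from the queue and runs them,
  // counting and logging (but not propagating) any std::exception they throw.
  // Exits when GetWork() returns the nullptr shutdown sentinel.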
  void Worker() {
    while (true) {
      std::function<void()> closure = GetWork();
      if (closure == nullptr) {
        break;
      }
      try {
        closure();
      } catch (const std::exception& ex) {
        TORCH_LAZY_COUNTER("ThreadPoolException", 1);
        LOG(ERROR) << "Exception from running thread pool closure: "
                   << ex.what();
      }
    }
  }

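  // Overflow path for Schedule(): runs the closure on a one-off detached
  // thread so it can never wait behind other closures in the pool.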
  void ScheduleOnThread(std::function<void()> closure) {
    std::thread thread(std::move(closure));
    thread.detach();
  }

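  // Blocks until a closure is available or the pool is shutting down. The
  // waiting_ counter tracks how many workers are parked here, which lets
  // Schedule() decide whether a new closure would be picked up promptly.
  // Returns nullptr as the shutdown sentinel.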
  std::function<void()> GetWork() {
    std::unique_lock<std::mutex> lock(mutex_);
    ++waiting_;
    cv_.wait(lock, [this] { return exiting_ || !work_.empty(); });
    --waiting_;
    if (work_.empty()) {
      return nullptr;
    }
    std::function<void()> closure(std::move(work_.front()));
    work_.pop_front();
    return closure;
  }

  std::vector<std::thread> threads_;
  std::mutex mutex_;
  std::condition_variable cv_;
  bool exiting_ = false;
  std::deque<std::function<void()>> work_;
  size_t waiting_ = 0;
};

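// Lazily constructed, process-wide I/O pool sized by the
// torch_lazy_io_thread_pool_size flag. The pool object is never deleted, so
// its destructor (which joins the workers) does not run during static
// destruction at process exit.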
ThreadPool* GetIoThreadPool() {
  static ThreadPool* pool =
      new ThreadPool(FLAGS_torch_lazy_io_thread_pool_size);
  return pool;
}

} // namespace

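// Shared state between a scheduled closure and the caller holding the
// Completion: records that the closure finished and any exception it threw,
// which Wait() rethrows on the caller's thread.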
class Completion::Data {
 public:
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return completed_; });
    if (exptr_ != nullptr) {
      std::rethrow_exception(exptr_);
    }
  }

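  // Wraps a closure so that any exception it throws is captured into the
  // shared Data and surfaced to the waiter, instead of escaping into the
  // thread pool's generic error handling.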
  static std::function<void()> GetCompleter(
      std::shared_ptr<Data> data,
      std::function<void()> closure) {
    auto closure_wrapper = [closure = std::move(closure), data]() {
      std::exception_ptr exptr;
      try {
        closure();
      } catch (...) {
        exptr = std::current_exception();
      }
      data->Complete(exptr);
    };
    return closure_wrapper;
  }

 private:
  void Complete(std::exception_ptr exptr) {
    std::lock_guard<std::mutex> lock(mutex_);
    exptr_ = std::move(exptr);
    completed_ = true;
    cv_.notify_all();
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  bool completed_ = false;
  std::exception_ptr exptr_;
};

Completion::Completion(std::shared_ptr<Data> data) : data_(std::move(data)) {}

void Completion::Wait() {
  data_->Wait();
}

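// Fire-and-forget scheduling on the shared I/O pool; std::exception-derived
// exceptions thrown by the closure are logged and counted by the pool but
// not reported back to the caller.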
void ScheduleIoClosure(std::function<void()> closure) {
  GetIoThreadPool()->Schedule(std::move(closure));
}

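// Like ScheduleIoClosure(), but returns a Completion the caller can Wait()
// on; any exception thrown by the closure is rethrown from Wait().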
Completion ScheduleIoClosureWithCompletion(std::function<void()> closure) {
  auto data = std::make_shared<Completion::Data>();
  GetIoThreadPool()->Schedule(
      Completion::Data::GetCompleter(data, std::move(closure)));
  return Completion(std::move(data));
}
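
// Illustrative usage sketch (hypothetical caller, not part of this file):
//
//   auto completion = torch::lazy::ScheduleIoClosureWithCompletion(
//       []() { /* perform some blocking I/O */ });
//   // ... do other work ...
//   completion.Wait(); // Rethrows any exception thrown by the closure.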

} // namespace lazy
} // namespace torch