1 | #include <torch/csrc/lazy/core/multi_wait.h> |
---|---|
2 | |
3 | #include <chrono> |
4 | #include <exception> |
5 | |
6 | namespace torch { |
7 | namespace lazy { |
8 | |
9 | void MultiWait::Done() { |
10 | bool notify = false; |
11 | { |
12 | std::lock_guard<std::mutex> lock(mutex_); |
13 | completed_count_ += 1; |
14 | notify = completed_count_ == count_; |
15 | } |
16 | if (notify) { |
17 | cv_.notify_all(); |
18 | } |
19 | } |
20 | |
21 | void MultiWait::Wait() { |
22 | std::unique_lock<std::mutex> lock(mutex_); |
23 | cv_.wait(lock, [this] { return completed_count_ >= count_; }); |
24 | if (exptr_ != nullptr) { |
25 | std::rethrow_exception(exptr_); |
26 | } |
27 | } |
28 | |
29 | void MultiWait::Wait(double wait_seconds) { |
30 | std::unique_lock<std::mutex> lock(mutex_); |
31 | if (!cv_.wait_for(lock, std::chrono::duration<double>(wait_seconds), [this] { |
32 | return completed_count_ >= count_; |
33 | })) { |
34 | throw std::runtime_error("Timeout"); |
35 | } |
36 | if (exptr_ != nullptr) { |
37 | std::rethrow_exception(exptr_); |
38 | } |
39 | } |
40 | |
41 | void MultiWait::Reset(size_t count) { |
42 | std::lock_guard<std::mutex> lock(mutex_); |
43 | count_ = count; |
44 | completed_count_ = 0; |
45 | exptr_ = nullptr; |
46 | } |
47 | |
48 | std::function<void()> MultiWait::Completer(std::function<void()> func) { |
49 | auto completer = [this, func = std::move(func)]() { Complete(func); }; |
50 | return completer; |
51 | } |
52 | |
53 | std::function<void()> MultiWait::Completer( |
54 | std::shared_ptr<MultiWait> mwait, |
55 | std::function<void()> func) { |
56 | auto completer = [mwait = std::move(mwait), func = std::move(func)]() { |
57 | mwait->Complete(func); |
58 | }; |
59 | return completer; |
60 | } |
61 | |
62 | void MultiWait::Complete(const std::function<void()>& func) { |
63 | try { |
64 | func(); |
65 | } catch (...) { |
66 | std::lock_guard<std::mutex> lock(mutex_); |
67 | exptr_ = std::current_exception(); |
68 | } |
69 | Done(); |
70 | } |
71 | |
72 | } // namespace lazy |
73 | } // namespace torch |
74 |