#include <torch/csrc/lazy/core/multi_wait.h>

#include <chrono>
#include <exception>
#include <stdexcept>
5
6namespace torch {
7namespace lazy {
8
9void MultiWait::Done() {
10 bool notify = false;
11 {
12 std::lock_guard<std::mutex> lock(mutex_);
13 completed_count_ += 1;
14 notify = completed_count_ == count_;
15 }
16 if (notify) {
17 cv_.notify_all();
18 }
19}
20
21void MultiWait::Wait() {
22 std::unique_lock<std::mutex> lock(mutex_);
23 cv_.wait(lock, [this] { return completed_count_ >= count_; });
24 if (exptr_ != nullptr) {
25 std::rethrow_exception(exptr_);
26 }
27}
28
29void MultiWait::Wait(double wait_seconds) {
30 std::unique_lock<std::mutex> lock(mutex_);
31 if (!cv_.wait_for(lock, std::chrono::duration<double>(wait_seconds), [this] {
32 return completed_count_ >= count_;
33 })) {
34 throw std::runtime_error("Timeout");
35 }
36 if (exptr_ != nullptr) {
37 std::rethrow_exception(exptr_);
38 }
39}
40
41void MultiWait::Reset(size_t count) {
42 std::lock_guard<std::mutex> lock(mutex_);
43 count_ = count;
44 completed_count_ = 0;
45 exptr_ = nullptr;
46}
47
48std::function<void()> MultiWait::Completer(std::function<void()> func) {
49 auto completer = [this, func = std::move(func)]() { Complete(func); };
50 return completer;
51}
52
53std::function<void()> MultiWait::Completer(
54 std::shared_ptr<MultiWait> mwait,
55 std::function<void()> func) {
56 auto completer = [mwait = std::move(mwait), func = std::move(func)]() {
57 mwait->Complete(func);
58 };
59 return completer;
60}
61
62void MultiWait::Complete(const std::function<void()>& func) {
63 try {
64 func();
65 } catch (...) {
66 std::lock_guard<std::mutex> lock(mutex_);
67 exptr_ = std::current_exception();
68 }
69 Done();
70}
71
72} // namespace lazy
73} // namespace torch
74