1 | /** |
2 | * This file is adapted from PyTorch/XLA |
3 | * https://github.com/pytorch/xla/blob/master/third_party/xla_client/multi_wait.h |
4 | */ |
5 | |
6 | #pragma once |
7 | |
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>

#include <c10/macros/Export.h>
14 | |
15 | namespace torch { |
16 | namespace lazy { |
17 | |
18 | // Support waiting for a number of tasks to complete. |
19 | class TORCH_API MultiWait { |
20 | public: |
21 | explicit MultiWait(size_t count) : count_(count) {} |
22 | |
23 | // Signal the completion of a single task. |
24 | void Done(); |
25 | |
26 | // Waits until at least count (passed as constructor value) completions |
27 | // happened. |
28 | void Wait(); |
29 | |
30 | // Same as above, but waits up to wait_seconds. |
31 | void Wait(double wait_seconds); |
32 | |
33 | // Resets the threshold counter for the MultiWait object. The completed count |
34 | // is also reset to zero. |
35 | void Reset(size_t count); |
36 | |
37 | // Creates a completer functor which signals the mult wait object once func |
38 | // has completed. Handles exceptions by signaling the multi wait with the |
39 | // proper status value. This API returns a function which captures a MultiWait |
40 | // reference, so care must be taken such that the reference remains valid for |
41 | // the whole lifetime of the returned function. |
42 | std::function<void()> Completer(std::function<void()> func); |
43 | |
44 | // Similar as the above API, but with explicit capture of the MultiWait shared |
45 | // pointer. |
46 | static std::function<void()> Completer( |
47 | std::shared_ptr<MultiWait> mwait, |
48 | std::function<void()> func); |
49 | |
50 | private: |
51 | void Complete(const std::function<void()>& func); |
52 | |
53 | std::mutex mutex_; |
54 | std::condition_variable cv_; |
55 | size_t count_ = 0; |
56 | size_t completed_count_ = 0; |
57 | std::exception_ptr exptr_; |
58 | }; |
59 | |
60 | } // namespace lazy |
61 | } // namespace torch |
62 | |