1/**
2 * This file is adapted from PyTorch/XLA
3 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/multi_wait.h
4 */
5
6#pragma once
7
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>

#include <c10/macros/Export.h>
14
15namespace torch {
16namespace lazy {
17
18// Support waiting for a number of tasks to complete.
19class TORCH_API MultiWait {
20 public:
21 explicit MultiWait(size_t count) : count_(count) {}
22
23 // Signal the completion of a single task.
24 void Done();
25
26 // Waits until at least count (passed as constructor value) completions
27 // happened.
28 void Wait();
29
30 // Same as above, but waits up to wait_seconds.
31 void Wait(double wait_seconds);
32
33 // Resets the threshold counter for the MultiWait object. The completed count
34 // is also reset to zero.
35 void Reset(size_t count);
36
37 // Creates a completer functor which signals the mult wait object once func
38 // has completed. Handles exceptions by signaling the multi wait with the
39 // proper status value. This API returns a function which captures a MultiWait
40 // reference, so care must be taken such that the reference remains valid for
41 // the whole lifetime of the returned function.
42 std::function<void()> Completer(std::function<void()> func);
43
44 // Similar as the above API, but with explicit capture of the MultiWait shared
45 // pointer.
46 static std::function<void()> Completer(
47 std::shared_ptr<MultiWait> mwait,
48 std::function<void()> func);
49
50 private:
51 void Complete(const std::function<void()>& func);
52
53 std::mutex mutex_;
54 std::condition_variable cv_;
55 size_t count_ = 0;
56 size_t completed_count_ = 0;
57 std::exception_ptr exptr_;
58};
59
60} // namespace lazy
61} // namespace torch
62