1 | #pragma once |
2 | |
#include <ATen/ATen.h>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>
6 | |
// Zero-length duration; it is the default `timeout` argument of
// c10d::Work::wait() declared further down in this header.
constexpr std::chrono::milliseconds kNoTimeout{0};
8 | |
9 | namespace c10d { |
10 | |
// The string constant "SEQ_NUM_STORE_KEY".
// NOTE(review): judging by the name this is the store key for a sequence
// number -- confirm at the use sites, which are not part of this header.
// `constexpr` already makes the pointer itself const, so no trailing `const`
// is spelled out.
constexpr const char* kSeqNumStoreKey = "SEQ_NUM_STORE_KEY";
12 | |
// Tags the kind of operation a Work object refers to. The underlying type is
// std::uint8_t, and every enumerator carries an explicit value so that adding
// or reordering entries cannot silently renumber the existing ones.
//
// NOTE(review): `_ALLGATHER_BASE` and `_REDUCE_SCATTER_BASE` start with an
// underscore followed by an uppercase letter, a spelling C++ reserves for the
// implementation. They are kept as-is because renaming would break callers.
enum class OpType : std::uint8_t {
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_COALESCED = 2,
  REDUCE = 3,
  ALLGATHER = 4,
  _ALLGATHER_BASE = 5,
  ALLGATHER_COALESCED = 6,
  GATHER = 7,
  SCATTER = 8,
  REDUCE_SCATTER = 9,
  ALLTOALL_BASE = 10,
  ALLTOALL = 11,
  SEND = 12,
  RECV = 13,
  RECVANYSOURCE = 14,
  BARRIER = 15,
  _REDUCE_SCATTER_BASE = 16,
  // Values 17..99 are unassigned. UNKNOWN is also the default `opType` of
  // the Work constructor.
  UNKNOWN = 100,
};
33 | |
34 | // Converts OpType to human readable string. |
35 | TORCH_API std::string opTypeToString(OpType opType); |
36 | |
37 | // Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE) |
38 | TORCH_API bool isP2POp(OpType opType, bool batchP2P = false); |
39 | |
40 | // Please do not use Work API, it is going away, to be |
41 | // replaced by ivalue::Future. |
42 | // Python binding for this class might change, please do not assume |
43 | // this will be bound using pybind. |
44 | class TORCH_API Work : public torch::CustomClassHolder { |
45 | public: |
46 | Work( |
47 | int rank = -1, |
48 | OpType opType = OpType::UNKNOWN, |
49 | const char* profilingTitle = nullptr, |
50 | const c10::optional<std::vector<at::Tensor>>& inputTensors = |
51 | c10::nullopt); |
52 | |
53 | ~Work() override; |
54 | |
55 | // Checks if request has completed. Non-blocking operation. |
56 | virtual bool isCompleted(); |
57 | |
58 | // Returns if the work completed successfully. |
59 | // If false, the exception function can be called to get details. |
60 | virtual bool isSuccess() const; |
61 | |
62 | // Returns exception if isSuccess() returned false. |
63 | virtual std::exception_ptr exception() const; |
64 | |
65 | // Returns source rank if this objects represents a recv-from-any. |
66 | virtual int sourceRank() const; |
67 | |
68 | // Returns result tensors, if applicable. |
69 | // If work is not supposed to have result, we return empty list. |
70 | virtual std::vector<at::Tensor> result(); |
71 | |
72 | // Ensures that operations on the output tensors that are invoked |
73 | // after this function returns are correctly sequenced after the |
74 | // asynchronous completion of this work. |
75 | // |
76 | // For CUDA tensors, it inserts stream synchronization such that |
77 | // the streams of the caller wait for completion of the |
78 | // asynchronous operations on the destination tensors. |
79 | // |
80 | // For CPU tensors, it is currently a nop. |
81 | // |
82 | // This function should only be used if the caller polls for |
83 | // completion through the `isCompleted` function, it has returned |
84 | // true, and the `isSuccess` function also has returned true. |
85 | // |
86 | virtual void synchronize(); |
87 | |
88 | // Waits until request completes. Blocking operation. |
89 | // Throws if the work completed with an exception. |
90 | // Returns false if the work is aborted. |
91 | // Otherwise, it always returns true, indicating the work is completed. |
92 | // |
93 | // Functionally equivalent to: |
94 | // |
95 | // while (!isCompleted()) { /* nop */ } |
96 | // auto success = isSuccess(); |
97 | // if (!success) { std::rethrow_exception(exception()); } |
98 | // return success; |
99 | // |
100 | virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout); |
101 | |
102 | virtual void abort(); |
103 | |
104 | // Returns a Future object that will be associated with the completion of |
105 | // work. Only NCCL backend is currently supported. |
106 | virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture(); |
107 | |
108 | OpType retrieveOpType(); |
109 | |
110 | static c10::intrusive_ptr<Work> create_from_future( |
111 | c10::intrusive_ptr<c10::ivalue::Future>); |
112 | |
113 | protected: |
114 | // Completes the work object and optionally sets the exception in a |
115 | // thread-safe manner. Notifies all waiting condition variables as well. |
116 | void finish(std::exception_ptr exception = nullptr); |
117 | |
118 | // Similar to finish, but throws an exception if one is already set or |
119 | // provided by the user. |
120 | void finishAndThrow(std::exception_ptr exception); |
121 | |
122 | mutable std::mutex mutex_; |
123 | std::condition_variable cv_; |
124 | bool completed_ = false; |
125 | std::exception_ptr exception_; |
126 | |
127 | // Current rank of the node. |
128 | const int rank_; |
129 | |
130 | // Operation type that this work object refers to. |
131 | OpType opType_; |
132 | |
133 | // When profiling, the callback to record end of operation event. This |
134 | // callback needs to be called when collective operation is complete. |
135 | std::function<void()> recordFunctionEndCallback_; |
136 | }; |
137 | |
138 | } // namespace c10d |
139 | |