1 | #pragma once |
2 | |
3 | #ifdef USE_C10D_UCC |
4 | |
5 | #include <torch/csrc/distributed/c10d/UCCUtils.hpp> |
6 | |
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
13 | |
14 | #include <torch/csrc/distributed/c10d/Backend.hpp> |
15 | #include <torch/csrc/distributed/c10d/Store.hpp> |
16 | #include <torch/csrc/distributed/c10d/Types.hpp> |
17 | #include <torch/csrc/distributed/c10d/Utils.hpp> |
18 | #ifdef USE_CUDA |
19 | #include <ATen/cuda/CUDAEvent.h> |
20 | #include <c10/cuda/CUDAStream.h> |
21 | #endif |
22 | |
23 | namespace c10d { |
24 | |
25 | #define TORCH_UCC_DEVICE_NOT_SET -2 |
26 | |
27 | #ifdef USE_CUDA |
28 | #define SAVE_TENSORS(_TENSORS, _DATA) \ |
29 | do { \ |
30 | if ((_TENSORS)[0].device().is_cuda()) { \ |
31 | for (const auto i : c10::irange((_TENSORS).size())) { \ |
32 | c10::cuda::CUDACachingAllocator::recordStream( \ |
33 | (_TENSORS)[i].storage().data_ptr(), (*stream)); \ |
34 | } \ |
35 | } else { \ |
36 | (_DATA) = (_TENSORS); \ |
37 | } \ |
38 | } while (0) |
39 | |
40 | #else |
41 | #define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS); |
42 | #endif |
43 | |
44 | constexpr const char* UCC_BACKEND_NAME = "ucc" ; |
45 | |
46 | struct event_pool_t { |
47 | #ifdef USE_CUDA |
48 | std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool; |
49 | #endif |
50 | std::mutex event_pool_mutex; |
51 | }; |
52 | |
53 | class Comm; |
54 | |
55 | // UCC does not support multiple CUDA devices per process. |
56 | class TORCH_API ProcessGroupUCC : public Backend { |
57 | private: |
58 | void set_timeout(ucc_coll_args_t& args); |
59 | |
60 | public: |
61 | class WorkData { |
62 | public: |
63 | std::vector<at::Tensor> src; |
64 | std::vector<at::Tensor> dst; |
65 | std::vector<at::Tensor> flat; |
66 | WorkData() {} |
67 | virtual ~WorkData() = default; |
68 | }; |
69 | class AlltoallWorkData : public WorkData { |
70 | public: |
71 | AlltoallWorkData(int size) |
72 | : send_lengths(size), |
73 | send_offsets(size), |
74 | recv_lengths(size), |
75 | recv_offsets(size) {} |
76 | std::vector<uint64_t> send_lengths; |
77 | std::vector<uint64_t> send_offsets; |
78 | std::vector<uint64_t> recv_lengths; |
79 | std::vector<uint64_t> recv_offsets; |
80 | }; |
81 | |
82 | class AllgathervWorkData : public WorkData { |
83 | public: |
84 | AllgathervWorkData(int size) : recv_lengths(size), recv_offsets(size) {} |
85 | std::vector<uint64_t> recv_lengths; |
86 | std::vector<uint64_t> recv_offsets; |
87 | }; |
88 | |
89 | class ScattervWorkData : public WorkData { |
90 | public: |
91 | ScattervWorkData(int size) : send_lengths(size), send_offsets(size) {} |
92 | std::vector<uint64_t> send_lengths; |
93 | std::vector<uint64_t> send_offsets; |
94 | }; |
95 | |
96 | class ProgressEntry { |
97 | friend class ProcessGroupUCC; |
98 | friend class Comm; |
99 | |
100 | public: |
101 | ProgressEntry(CommBase* comm, ucc_coll_req_h request) |
102 | : status_(UCC_INPROGRESS), comm_(comm), request_(request) {} |
103 | // Finalizes UCC status or exception of collective request. |
104 | void finalize(std::exception_ptr eptr = nullptr); |
105 | ucc_status_t status_; |
106 | CommBase* comm_; |
107 | ucc_coll_req_h request_; |
108 | std::unique_ptr<WorkData> data; |
109 | c10::intrusive_ptr<c10::ivalue::Future> future_; |
110 | std::exception_ptr eptr_; |
111 | }; |
112 | |
113 | class WorkUCC : public Work { |
114 | friend class ProcessGroupUCC; |
115 | friend class Comm; |
116 | |
117 | public: |
118 | WorkUCC( |
119 | OpType opType, |
120 | uint64_t seq, |
121 | const char* prof_title, |
122 | const c10::optional<std::vector<at::Tensor>>& inputs, |
123 | const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger) |
124 | : Work(-1, opType, prof_title, inputs), logger_(logger), seq_(seq) {} |
125 | ~WorkUCC(); |
126 | void setException(); |
127 | void setAndThrowException(); |
128 | bool isCompleted() override; |
129 | bool isSuccess() const override; |
130 | bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override; |
131 | c10::intrusive_ptr<c10::ivalue::Future> getFuture() override; |
132 | std::vector<at::Tensor> result() override; |
133 | int sourceRank() const override; |
134 | #ifdef USE_CUDA |
135 | std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr; |
136 | event_pool_t* ep = nullptr; |
137 | #endif |
138 | int sourceRank_; |
139 | |
140 | protected: |
141 | std::shared_ptr<ProgressEntry> entry_; |
142 | c10::intrusive_ptr<ProcessGroupUCCLogger> logger_; |
143 | uint64_t seq_; |
144 | |
145 | private: |
146 | // The future returned by getFuture. |
147 | c10::intrusive_ptr<at::ivalue::Future> future_; |
148 | // Store a reference to collective's outputs, used by result |
149 | std::shared_ptr<std::vector<at::Tensor>> outputs_; |
150 | }; |
151 | |
152 | explicit ProcessGroupUCC( |
153 | const c10::intrusive_ptr<Store>& store, |
154 | int rank = -1, |
155 | int size = -1, |
156 | std::chrono::duration<float> timeout = kBackendDefaultTimeout); |
157 | |
158 | void initComm(c10::Device dev); |
159 | |
160 | ~ProcessGroupUCC() override; |
161 | |
162 | const std::string getBackendName() const override { |
163 | return std::string(UCC_BACKEND_NAME); |
164 | } |
165 | |
166 | #ifdef USE_CUDA |
167 | std::unique_ptr<at::cuda::CUDAEvent> getPooledEvent(); |
168 | #endif |
169 | |
170 | // Performs a health check by initializing dummy UCC & UCX communicators and |
171 | // then destroying them. This will help indicate and signal any |
172 | // UCC/UCX-related issues prior to the first collective. The actual |
173 | // initialization and subsequent destruction is ran on a separate thread and |
174 | // the main thread is signalled about timeouts/errors to report to the |
175 | // application. |
176 | void runHealthCheck(); |
177 | |
178 | template <typename PreProcess, typename PostProcess> |
179 | c10::intrusive_ptr<Work> collective_post( |
180 | OpType opType, |
181 | PreProcess preproc, |
182 | PostProcess postproc, |
183 | ucc_coll_args_t& coll, |
184 | std::unique_ptr<ProcessGroupUCC::WorkData> data, |
185 | c10::Device dev, |
186 | std::vector<at::Tensor>& inputTensors, |
187 | std::vector<at::Tensor>& outputTensors, |
188 | const char* prof_title); |
189 | |
190 | c10::intrusive_ptr<Work> broadcast( |
191 | std::vector<at::Tensor>& data, |
192 | const BroadcastOptions& opts = BroadcastOptions()) override; |
193 | |
194 | c10::intrusive_ptr<Work> allreduce( |
195 | std::vector<at::Tensor>& tensors, |
196 | const AllreduceOptions& opts = AllreduceOptions()) override; |
197 | |
198 | c10::intrusive_ptr<Work> allreduce_coalesced( |
199 | std::vector<at::Tensor>& tensors, |
200 | const AllreduceCoalescedOptions& opts = |
201 | AllreduceCoalescedOptions()) override; |
202 | |
203 | c10::intrusive_ptr<Work> reduce( |
204 | std::vector<at::Tensor>& tensors, |
205 | const ReduceOptions& opts = ReduceOptions()) override; |
206 | |
207 | c10::intrusive_ptr<Work> allgather( |
208 | std::vector<std::vector<at::Tensor>>& outputTensors, |
209 | std::vector<at::Tensor>& inputTensors, |
210 | const AllgatherOptions& opts = AllgatherOptions()) override; |
211 | |
212 | c10::intrusive_ptr<Work> _allgather_base( |
213 | at::Tensor& outputBuffer, |
214 | at::Tensor& inputBuffer, |
215 | const AllgatherOptions& opts = AllgatherOptions()) override; |
216 | |
217 | c10::intrusive_ptr<Work> barrier( |
218 | const BarrierOptions& opts = BarrierOptions()) override; |
219 | |
220 | c10::intrusive_ptr<Work> gather( |
221 | std::vector<std::vector<at::Tensor>>& outputTensors, |
222 | std::vector<at::Tensor>& inputTensors, |
223 | const GatherOptions& opts = GatherOptions()) override; |
224 | |
225 | c10::intrusive_ptr<Work> scatter( |
226 | std::vector<at::Tensor>& outputTensors, |
227 | std::vector<std::vector<at::Tensor>>& inputTensors, |
228 | const ScatterOptions& opts = ScatterOptions()) override; |
229 | |
230 | c10::intrusive_ptr<Work> reduce_scatter( |
231 | std::vector<at::Tensor>& outputTensors, |
232 | std::vector<std::vector<at::Tensor>>& inputTensors, |
233 | const ReduceScatterOptions& opts = ReduceScatterOptions()) override; |
234 | |
235 | c10::intrusive_ptr<Work> alltoall_base( |
236 | at::Tensor& outputTensor, |
237 | at::Tensor& inputTensor, |
238 | std::vector<int64_t>& outputSplitSizes, |
239 | std::vector<int64_t>& inputSplitSizes, |
240 | const AllToAllOptions& opts = AllToAllOptions()) override; |
241 | |
242 | c10::intrusive_ptr<Work> alltoall( |
243 | std::vector<at::Tensor>& outputTensors, |
244 | std::vector<at::Tensor>& inputTensors, |
245 | const AllToAllOptions& opts = AllToAllOptions()) override; |
246 | |
247 | c10::intrusive_ptr<Work> send( |
248 | std::vector<at::Tensor>& tensors, |
249 | int dstRank, |
250 | int tag) override; |
251 | |
252 | c10::intrusive_ptr<Work> recv( |
253 | std::vector<at::Tensor>& tensors, |
254 | int srcRank, |
255 | int tag) override; |
256 | |
257 | // Counting for the sequential number of UCC collective_post call. |
258 | uint64_t seq_{0}; |
259 | |
260 | // Agrees on an initial sequence number for the whole group by having rank 0 |
261 | // create it and broadcast it to other ranks using the store. |
262 | void setSequenceNumberForGroup() override; |
263 | |
264 | // Retrieves the current sequence number for the whole group, which should be |
265 | // in sync. If the returned number is not consistent across the group, it |
266 | // may indicate that there is some sort of collective desynchronization. |
267 | uint64_t getSequenceNumberForGroup() override; |
268 | |
269 | static c10::intrusive_ptr<Backend> createProcessGroupUCC( |
270 | const c10::intrusive_ptr<::c10d::Store>& store, |
271 | int rank, |
272 | int size, |
273 | const std::chrono::duration<float>& timeout); |
274 | |
275 | protected: |
276 | const std::chrono::duration<float> timeout_; |
277 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob; |
278 | std::shared_ptr<Comm> comm = {nullptr}; |
279 | uint32_t comm_id; |
280 | ucc_team_h team{nullptr}; |
281 | ucc_ee_h cuda_ee{nullptr}; |
282 | |
283 | #ifdef USE_CUDA |
284 | std::unique_ptr<at::cuda::CUDAStream> stream = nullptr; |
285 | event_pool_t ep; |
286 | #endif |
287 | c10::intrusive_ptr<ProcessGroupUCCLogger> logger; |
288 | }; |
289 | |
290 | class Comm { |
291 | c10::intrusive_ptr<ProcessGroupUCCLogger> logger; |
292 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob; |
293 | CommUCC ucc_comm; |
294 | std::mutex mutex; |
295 | std::thread progress_thread; |
296 | std::condition_variable queue_produce_cv; |
297 | std::condition_variable queue_consume_cv; |
298 | std::deque<std::shared_ptr<ProcessGroupUCC::ProgressEntry>> progress_queue; |
299 | bool stop_progress_loop; |
300 | bool collective_inprogress; |
301 | torch_ucc_phase_t finalize_phase; |
302 | |
303 | public: |
304 | c10::DeviceIndex cuda_device_index; |
305 | Comm( |
306 | const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger, |
307 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob, |
308 | c10::Device dev, |
309 | bool is_health_check); |
310 | |
311 | ~Comm(); |
312 | |
313 | void ucc_create_team( |
314 | ucc_team_h& team, |
315 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob); |
316 | |
317 | void ucc_destroy_team(ucc_team_h& team); |
318 | |
319 | c10::intrusive_ptr<Work> enqueue_p2p( |
320 | OpType opType, |
321 | ucc_coll_req_h request, |
322 | const char* prof_title); |
323 | |
324 | #ifdef USE_CUDA |
325 | void enqueue_cuda_collective( |
326 | std::unique_ptr<ProcessGroupUCC::WorkData> data, |
327 | c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work, |
328 | ucc_coll_args_t& coll, |
329 | ucc_team_h team, |
330 | ucc_ee_h ee); |
331 | #endif |
332 | |
333 | void enqueue_collective( |
334 | std::unique_ptr<ProcessGroupUCC::WorkData> data, |
335 | c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work, |
336 | ucc_coll_args_t& coll, |
337 | ucc_team_h team); |
338 | |
339 | static std::shared_ptr<Comm> get_comm( |
340 | uint32_t& id, |
341 | c10::Device dev, |
342 | std::shared_ptr<torch_ucc_oob_coll_info_t> oob, |
343 | const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger, |
344 | bool is_health_check = false); |
345 | |
346 | void progress_loop(); |
347 | }; |
348 | |
349 | } // namespace c10d |
350 | |
351 | #endif // USE_C10D_UCC |
352 | |