#pragma once

#ifdef USE_C10D_UCC

#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <condition_variable>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#endif

namespace c10d {

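// Sentinel device index meaning no CUDA device has been associated with the
// communicator yet (see Comm::cuda_device_index below).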
#define TORCH_UCC_DEVICE_NOT_SET -2

#ifdef USE_CUDA
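// For CUDA tensors, recordStream tells the caching allocator that each
// tensor's storage is in use on the communication stream, so its memory is
// not freed or reused until the work queued on that stream completes. For
// CPU tensors, the tensors are instead copied into _DATA so the owning
// WorkData keeps them alive until the collective finishes.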
#define SAVE_TENSORS(_TENSORS, _DATA)                       \
  do {                                                      \
    if ((_TENSORS)[0].device().is_cuda()) {                 \
      for (const auto i : c10::irange((_TENSORS).size())) { \
        c10::cuda::CUDACachingAllocator::recordStream(      \
            (_TENSORS)[i].storage().data_ptr(), (*stream)); \
      }                                                     \
    } else {                                                \
      (_DATA) = (_TENSORS);                                 \
    }                                                       \
  } while (0)

#else
#define SAVE_TENSORS(_TENSORS, _DATA) (_DATA) = (_TENSORS);
#endif

constexpr const char* UCC_BACKEND_NAME = "ucc";

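// Mutex-guarded pool of reusable CUDA events; completed WorkUCC objects
// return their fence events here so later collectives can recycle them
// instead of allocating new ones (see getPooledEvent()).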
struct event_pool_t {
#ifdef USE_CUDA
  std::queue<std::unique_ptr<at::cuda::CUDAEvent>> event_pool;
#endif
  std::mutex event_pool_mutex;
};

class Comm;

// UCC does not support multiple CUDA devices per process.
class TORCH_API ProcessGroupUCC : public Backend {
 private:
  void set_timeout(ucc_coll_args_t& args);

 public:
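  // Holds the tensors referenced by an in-flight collective so they stay
  // alive until the request completes; `flat` stores intermediate flattened
  // buffers for collectives that require contiguous storage.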
  class WorkData {
   public:
    std::vector<at::Tensor> src;
    std::vector<at::Tensor> dst;
    std::vector<at::Tensor> flat;
    WorkData() = default;
    virtual ~WorkData() = default;
  };
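
  // Per-rank counts and displacements used by the vector ("v") variants of
  // the collectives (alltoallv, allgatherv, scatterv); each vector is sized
  // to the number of ranks in the group.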
  class AlltoallWorkData : public WorkData {
   public:
    AlltoallWorkData(int size)
        : send_lengths(size),
          send_offsets(size),
          recv_lengths(size),
          recv_offsets(size) {}
    std::vector<uint64_t> send_lengths;
    std::vector<uint64_t> send_offsets;
    std::vector<uint64_t> recv_lengths;
    std::vector<uint64_t> recv_offsets;
  };

  class AllgathervWorkData : public WorkData {
   public:
    AllgathervWorkData(int size) : recv_lengths(size), recv_offsets(size) {}
    std::vector<uint64_t> recv_lengths;
    std::vector<uint64_t> recv_offsets;
  };

  class ScattervWorkData : public WorkData {
   public:
    ScattervWorkData(int size) : send_lengths(size), send_offsets(size) {}
    std::vector<uint64_t> send_lengths;
    std::vector<uint64_t> send_offsets;
  };

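  // A posted collective as tracked by the progress thread: pairs the UCC
  // request handle with its status, the tensors retained for its lifetime,
  // and the future that is marked completed (or errored) when it finishes.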
  class ProgressEntry {
    friend class ProcessGroupUCC;
    friend class Comm;

   public:
    ProgressEntry(CommBase* comm, ucc_coll_req_h request)
        : status_(UCC_INPROGRESS), comm_(comm), request_(request) {}
    // Finalizes the UCC status or the exception of the collective request.
    void finalize(std::exception_ptr eptr = nullptr);
    ucc_status_t status_;
    CommBase* comm_;
    ucc_coll_req_h request_;
    std::unique_ptr<WorkData> data;
    c10::intrusive_ptr<c10::ivalue::Future> future_;
    std::exception_ptr eptr_;
  };

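  // Work handle returned to callers; completion, exceptions, and results are
  // populated by the progress thread through the associated ProgressEntry.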
  class WorkUCC : public Work {
    friend class ProcessGroupUCC;
    friend class Comm;

   public:
    WorkUCC(
        OpType opType,
        uint64_t seq,
        const char* prof_title,
        const c10::optional<std::vector<at::Tensor>>& inputs,
        const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger)
        : Work(-1, opType, prof_title, inputs), logger_(logger), seq_(seq) {}
    ~WorkUCC();
    void setException();
    void setAndThrowException();
    bool isCompleted() override;
    bool isSuccess() const override;
    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
    std::vector<at::Tensor> result() override;
    int sourceRank() const override;
#ifdef USE_CUDA
    std::unique_ptr<at::cuda::CUDAEvent> fence = nullptr;
    event_pool_t* ep = nullptr;
#endif
    int sourceRank_;

   protected:
    std::shared_ptr<ProgressEntry> entry_;
    c10::intrusive_ptr<ProcessGroupUCCLogger> logger_;
    uint64_t seq_;

   private:
    // The future returned by getFuture().
    c10::intrusive_ptr<c10::ivalue::Future> future_;
    // A reference to the collective's outputs, returned by result().
    std::shared_ptr<std::vector<at::Tensor>> outputs_;
  };

  explicit ProcessGroupUCC(
      const c10::intrusive_ptr<Store>& store,
      int rank = -1,
      int size = -1,
      std::chrono::duration<float> timeout = kBackendDefaultTimeout);

  void initComm(c10::Device dev);

  ~ProcessGroupUCC() override;

  const std::string getBackendName() const override {
    return std::string(UCC_BACKEND_NAME);
  }

#ifdef USE_CUDA
  std::unique_ptr<at::cuda::CUDAEvent> getPooledEvent();
#endif

  // Performs a health check by initializing dummy UCC & UCX communicators
  // and then destroying them. This helps surface any UCC/UCX-related issues
  // before the first collective. The actual initialization and subsequent
  // destruction run on a separate thread, and the main thread is signaled
  // about timeouts/errors to report to the application.
  void runHealthCheck();

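  // Common path for posting a collective: runs `preproc`, saves the input
  // and output tensors via `data`, posts the UCC collective (on the
  // dedicated CUDA stream when `dev` is a CUDA device), runs `postproc`,
  // and returns the Work handle that the progress thread completes.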
  template <typename PreProcess, typename PostProcess>
  c10::intrusive_ptr<Work> collective_post(
      OpType opType,
      PreProcess preproc,
      PostProcess postproc,
      ucc_coll_args_t& coll,
      std::unique_ptr<ProcessGroupUCC::WorkData> data,
      c10::Device dev,
      std::vector<at::Tensor>& inputTensors,
      std::vector<at::Tensor>& outputTensors,
      const char* prof_title);

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  // Counter for the sequence number of UCC collective_post calls.
  uint64_t seq_{0};

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to the other ranks via the store.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should
  // be in sync across ranks. An inconsistent value may indicate some sort of
  // collective desynchronization.
  uint64_t getSequenceNumberForGroup() override;

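  // Factory used to register the UCC backend with c10d. Illustrative usage
  // (a sketch; store construction and error handling elided):
  //   auto pg = ProcessGroupUCC::createProcessGroupUCC(
  //       store, /*rank=*/0, /*size=*/world_size, std::chrono::seconds(60));
  //   pg->allreduce(tensors)->wait();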
  static c10::intrusive_ptr<Backend> createProcessGroupUCC(
      const c10::intrusive_ptr<::c10d::Store>& store,
      int rank,
      int size,
      const std::chrono::duration<float>& timeout);

 protected:
  const std::chrono::duration<float> timeout_;
  std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
  std::shared_ptr<Comm> comm = nullptr;
  uint32_t comm_id;
  ucc_team_h team{nullptr};
  ucc_ee_h cuda_ee{nullptr};

#ifdef USE_CUDA
  std::unique_ptr<at::cuda::CUDAStream> stream = nullptr;
  event_pool_t ep;
#endif
  c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
};

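// Owns the UCC communicator and the progress thread that advances posted
// requests; a single Comm instance may be shared by several process groups
// (see get_comm()).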
class Comm {
  c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
  std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
  CommUCC ucc_comm;
  std::mutex mutex;
  std::thread progress_thread;
  std::condition_variable queue_produce_cv;
  std::condition_variable queue_consume_cv;
  std::deque<std::shared_ptr<ProcessGroupUCC::ProgressEntry>> progress_queue;
  bool stop_progress_loop;
  bool collective_inprogress;
  torch_ucc_phase_t finalize_phase;

 public:
  c10::DeviceIndex cuda_device_index;
  Comm(
      const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
      std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
      c10::Device dev,
      bool is_health_check);

  ~Comm();

  void ucc_create_team(
      ucc_team_h& team,
      std::shared_ptr<torch_ucc_oob_coll_info_t> oob);

  void ucc_destroy_team(ucc_team_h& team);

  c10::intrusive_ptr<Work> enqueue_p2p(
      OpType opType,
      ucc_coll_req_h request,
      const char* prof_title);

#ifdef USE_CUDA
  void enqueue_cuda_collective(
      std::unique_ptr<ProcessGroupUCC::WorkData> data,
      c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
      ucc_coll_args_t& coll,
      ucc_team_h team,
      ucc_ee_h ee);
#endif

  void enqueue_collective(
      std::unique_ptr<ProcessGroupUCC::WorkData> data,
      c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
      ucc_coll_args_t& coll,
      ucc_team_h team);

  static std::shared_ptr<Comm> get_comm(
      uint32_t& id,
      c10::Device dev,
      std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
      const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
      bool is_health_check = false);

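  // Progress-thread body: blocks on progress_queue, drives each posted
  // request through UCC progress calls until it completes or errors, then
  // finalizes the corresponding ProgressEntry.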
  void progress_loop();
};

} // namespace c10d

#endif // USE_C10D_UCC