1#ifdef USE_C10D_UCC
2
3#include <torch/csrc/distributed/c10d/ProcessGroupUCC.hpp>
4#include <torch/csrc/distributed/c10d/UCCTracing.hpp>
5#include <torch/csrc/distributed/c10d/UCCUtils.hpp>
6#include <list>
7#include <memory>
8#include <unordered_map>
9#include <unordered_set>
10
11namespace c10d {
12
13namespace {
14constexpr int64_t kBusyWaitMillis = 10;
15
16const std::map<c10::DeviceType, ucc_memory_type_t> ucc_mtype_map = {
17 {c10::kCPU, UCC_MEMORY_TYPE_HOST},
18 {c10::kCUDA, UCC_MEMORY_TYPE_CUDA},
19};
20
21ucc_memory_type_t to_ucc_memType(c10::DeviceType _c10_type) {
22 if (ucc_mtype_map.find(_c10_type) != ucc_mtype_map.end())
23 return ucc_mtype_map.at(_c10_type);
24 else
25 return UCC_MEMORY_TYPE_UNKNOWN;
26}
27
28const std::map<at::ScalarType, ucc_datatype_t> ucc_dtype_map = {
29 {at::kByte, UCC_DT_UINT8},
30 {at::kChar, UCC_DT_INT8},
31 {at::kHalf, UCC_DT_FLOAT16},
32 {at::kBFloat16, UCC_DT_BFLOAT16},
33 {at::kDouble, UCC_DT_FLOAT64},
34 {at::kFloat, UCC_DT_FLOAT32},
35 {at::kInt, UCC_DT_INT32},
36 {at::kLong, UCC_DT_INT64},
37 {at::kBool, UCC_DT_UINT8},
38};
39
40ucc_datatype_t to_ucc_dType(at::Tensor _tensor) {
41 if (_tensor.scalar_type() == at::kBool && _tensor.element_size() != 1) {
42 TORCH_CHECK(
43 false, "Size of Boolean type larger than 1 is not supported in UCC");
44 }
45 try {
46 return ucc_dtype_map.at(_tensor.scalar_type());
47 } catch (const std::out_of_range& e) {
48 TORCH_CHECK(false, "Not supported data type for UCC");
49 }
50}
51
52const std::map<ReduceOp, ucc_reduction_op_t> ucc_op_map = {
53 {ReduceOp::SUM, UCC_OP_SUM},
54 {ReduceOp::PRODUCT, UCC_OP_PROD},
55 {ReduceOp::MIN, UCC_OP_MIN},
56 {ReduceOp::MAX, UCC_OP_MAX},
57 {ReduceOp::BAND, UCC_OP_BAND},
58 {ReduceOp::BOR, UCC_OP_BOR},
59 {ReduceOp::BXOR, UCC_OP_BXOR},
60 {ReduceOp::AVG, UCC_OP_AVG},
61};
62
63ucc_reduction_op_t to_ucc_reduceOp(
64 const ReduceOp _op,
65 const at::ScalarType _dt) {
66 if (_dt == at::kBool) {
67 if (_op == ReduceOp::SUM) {
68 // bitwise or
69 return UCC_OP_MAX;
70 } else if (_op == ReduceOp::PRODUCT) {
71 // bitwise and
72 return UCC_OP_MIN;
73 } else if (_op == ReduceOp::AVG) {
74 TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs");
75 }
76 }
77
78 try {
79 return ucc_op_map.at(_op);
80 } catch (const std::out_of_range& e) {
81 TORCH_CHECK(false, "Not supported ReduceOp for UCC");
82 }
83}
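// Illustrative mapping examples (not invoked here; shown for clarity): boolean
// reductions are remapped so that SUM acts as a logical OR and PRODUCT as a
// logical AND over {0, 1} values, e.g.
//   to_ucc_reduceOp(ReduceOp::SUM, at::kBool)     == UCC_OP_MAX
//   to_ucc_reduceOp(ReduceOp::PRODUCT, at::kBool) == UCC_OP_MIN
//   to_ucc_reduceOp(ReduceOp::SUM, at::kFloat)    == UCC_OP_SUM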
84
85struct torch_ucc_config_t {
86 c10::once_flag flag;
87 std::array<bool, 32> blocking_wait;
88 bool enable_comms_logger;
89 bool use_future;
  // Share the UCC communicator among multiple PGs to save resources.
  bool shared_comm;
  // Use allgatherv to implement allgather, without flattening the list of
  // (potentially non-contiguous) tensors.
  bool use_allgatherv;
95 bool enable_health_check;
96} torch_ucc_config;
97
98std::unordered_map<std::string, std::string> torch_ucc_envs_map = {
99 // TORCH_UCC_BLOCKING_WAIT allowed syntax:
100 // - TORCH_UCC_BLOCKING_WAIT=none --> blocking wait completely disabled
101 // - TORCH_UCC_BLOCKING_WAIT=all --> blocking wait completely enabled
102 // - TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv --> blocking wait enabled
103 // on selected operations
104 // Supported operations:
105 // [allgather,allgather_base,allreduce,alltoall,broadcast,
106 // gather,reduce,reduce_scatter,scatter,send,recv]
107 {"TORCH_UCC_BLOCKING_WAIT", "none"},
108
109 {"TORCH_UCC_USE_FUTURE", "1"},
110 {"TORCH_UCC_PROFILING_ENABLE", "0"},
111 {"TORCH_UCC_SHARED_COMM", "1"},
112 {"TORCH_UCC_USE_ALLGATHERV", "0"},
113 {"TORCH_UCC_ENABLE_HEALTH_CHECK", "0"},
114 {"TORCH_UCC_ENABLE_COMMS_LOGGER", "0"},
115};
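// Example (assumed launcher environment, for illustration only): any of the
// defaults above can be overridden before the process group is created, e.g.
//   export TORCH_UCC_BLOCKING_WAIT=allreduce,send,recv
//   export TORCH_UCC_SHARED_COMM=0
// read_config() below reads each key via std::getenv and updates this map.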
116
117std::vector<OpType> parse_blocking_wait(std::string op_list_string) {
118 const static std::unordered_map<std::string, OpType> str2op = {
119 {"allgather", OpType::ALLGATHER},
120 {"allgather_base", OpType::_ALLGATHER_BASE},
121 {"allreduce", OpType::ALLREDUCE},
122 {"alltoall_base", OpType::ALLTOALL_BASE},
123 {"broadcast", OpType::BROADCAST},
124 {"gather", OpType::GATHER},
125 {"reduce", OpType::REDUCE},
126 {"reduce_scatter", OpType::REDUCE_SCATTER},
127 {"scatter", OpType::SCATTER},
128 {"send", OpType::SEND},
129 {"recv", OpType::RECV},
130 };
131 auto op_list = parse_list(op_list_string);
132 if (op_list == std::vector<std::string>{"none"}) {
133 return {};
134 }
135 std::vector<OpType> result;
136 if (op_list == std::vector<std::string>{"all"}) {
137 for (auto entry : str2op) {
138 result.push_back(entry.second);
139 }
140 } else {
141 for (auto op_string : op_list) {
142 result.push_back(str2op.at(op_string));
143 }
144 }
145 return result;
146}
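// Minimal usage sketch (assumed input, mirrors what read_config() does below):
//   auto ops = parse_blocking_wait("allreduce,send,recv");
//   // ops == {OpType::ALLREDUCE, OpType::SEND, OpType::RECV}
//   for (auto op : ops)
//     torch_ucc_config.blocking_wait[(std::uint8_t)op] = true;
// "none" yields an empty vector, "all" selects every entry of str2op, and an
// unknown operation name makes str2op.at() throw std::out_of_range.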
147
148} // namespace
149
150void read_config() {
151 // default configuration
152 torch_ucc_config.blocking_wait.fill(false);
153 torch_ucc_config.use_future = true;
154 torch_ucc_config.shared_comm = false;
155 torch_ucc_config.use_allgatherv = false;
156 torch_ucc_config.enable_health_check = false;
157 torch_ucc_config.enable_comms_logger = false;
158
159 // read all torch_ucc env. variables and update the map
160 char* env;
161 for (auto& torch_ucc_env : torch_ucc_envs_map) {
162 env = std::getenv(torch_ucc_env.first.c_str());
163 if (env) {
164 torch_ucc_envs_map[torch_ucc_env.first] = std::string(env);
165 }
166 }
167
168 auto blocking_wait_str = torch_ucc_envs_map.at("TORCH_UCC_BLOCKING_WAIT");
169 for (auto op : parse_blocking_wait(blocking_wait_str)) {
170 torch_ucc_config.blocking_wait[(std::uint8_t)op] = true;
171 }
172 // barrier is always blocking
173 torch_ucc_config.blocking_wait[(std::uint8_t)OpType::BARRIER] = true;
177
178 torch_ucc_config.use_future =
179 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_FUTURE"));
180 torch_ucc_config.shared_comm =
181 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_SHARED_COMM"));
182 torch_ucc_config.use_allgatherv =
183 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_USE_ALLGATHERV"));
184 torch_ucc_config.enable_health_check =
185 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_HEALTH_CHECK"));
186 torch_ucc_config.enable_comms_logger =
187 std::stoi(torch_ucc_envs_map.at("TORCH_UCC_ENABLE_COMMS_LOGGER"));
188}
189
190void check_device(c10::Device dev1, c10::Device dev2) {
191 if (dev1.is_cuda() && dev2.is_cuda() && dev1 != dev2) {
192 throw std::runtime_error("ProcessGroupUCC multidevice is not supported");
193 }
194}
195
196void check_tensor(const std::vector<at::Tensor>& tensors) {
197 if (tensors.size() != 1) {
198 throw std::runtime_error(
199 "ProcessGroupUCC takes 1 tensor. Got " +
200 std::to_string(tensors.size()) + ". ");
201 }
202 if (!tensors[0].is_contiguous()) {
203 throw std::runtime_error(
204 "ProcessGroupUCC input tensor has to be contiguous");
205 }
206 if (tensors[0].is_sparse()) {
207 throw std::runtime_error("ProcessGroupUCC input tensor has to be dense");
208 }
209 // TODO: check cuda case
210}
211
212ProcessGroupUCC::WorkUCC::~WorkUCC() {
213#ifdef USE_CUDA
214 if (fence && ep) {
215 std::lock_guard<std::mutex> lock(ep->event_pool_mutex);
216 ep->event_pool.push(std::move(fence));
217 }
218#endif
219}
220
221void ProcessGroupUCC::WorkUCC::setException() {
222 if (exception() || !entry_) {
223 return;
224 }
225 exception_ = entry_->eptr_;
226}
227
228void ProcessGroupUCC::WorkUCC::setAndThrowException() {
229 setException();
230 if (exception()) {
231 std::rethrow_exception(exception());
232 }
233}
234
235bool ProcessGroupUCC::WorkUCC::isCompleted() {
236 if (!entry_) {
237 return true;
238 }
239 setException();
240 // status_ <= 0 to avoid listing all possible status codes. The main thread
241 // needs to be unblocked when UCC (in progress thread) returns success (== 0)
242 // or any error code (< 0).
243 return exception() || entry_->status_ <= 0;
244}
245
246bool ProcessGroupUCC::WorkUCC::isSuccess() const {
247 if (!entry_) {
248 return true;
249 }
250 return !exception() && entry_->status_ == 0;
251}
252
253bool ProcessGroupUCC::WorkUCC::wait(std::chrono::milliseconds /* unused */) {
254 if (torch_ucc_config.enable_comms_logger && logger_) {
255 logger_->trace_generator->recordComms("wait", (uintptr_t)this, rank_);
256 }
257#ifdef USE_CUDA
258 if (fence && !torch_ucc_config.blocking_wait[(int)opType_]) {
259 // block user stream
260 setAndThrowException();
261 fence->block(at::cuda::getCurrentCUDAStream());
262 return true;
263 }
264#endif
  // Wait for completion. In the blocking case, the main thread is blocked in
  // this loop until the progress thread changes the status of this request.
  // If a timeout occurs, UCC returns UCC_ERR_TIMEOUT as the status and the
  // main thread then throws the exception. There is currently no "abort"
  // function in UCC.
270 while (!isCompleted())
271 ;
272 setAndThrowException();
273 // manually call profiling end callbacks if they are set,
274 // since progress thread does not own WorkUCC
275 if (Work::recordFunctionEndCallback_) {
276 Work::recordFunctionEndCallback_();
277 Work::recordFunctionEndCallback_ = nullptr;
278 }
279 return true;
280}
281
282c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupUCC::WorkUCC::getFuture() {
283 return future_;
284}
285
286int ProcessGroupUCC::WorkUCC::sourceRank() const {
287 if (opType_ != OpType::RECV && opType_ != OpType::RECVANYSOURCE) {
288 // Throw an error
289 return Work::sourceRank();
290 }
291 return sourceRank_;
292}
293
294std::vector<at::Tensor> ProcessGroupUCC::WorkUCC::result() {
295 return *outputs_;
296}
297
298void ProcessGroupUCC::ProgressEntry::finalize(std::exception_ptr eptr) {
299 ucc_status_t status = UCC_OK;
300
301 if (request_ != nullptr) {
302 status = request_->status;
303 comm_->free_request(request_);
304 }
305 if (eptr) {
306 eptr_ = eptr;
307 } else {
308 status_ = status;
309 }
310 if (future_) {
311 if (eptr) {
312 future_->setError(eptr);
313 } else {
314 future_->markCompleted(
315 c10::IValue(data ? data->dst : std::vector<at::Tensor>()));
316 }
317 }
318}
319
320Comm::Comm(
321 const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger_,
322 std::shared_ptr<torch_ucc_oob_coll_info_t> oob_,
323 c10::Device dev,
324 bool is_health_check)
325 : logger(logger_),
326 oob(oob_),
327 ucc_comm(oob, logger),
328 finalize_phase(
329 is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_FINALIZE),
330 cuda_device_index(TORCH_UCC_DEVICE_NOT_SET) {
331 if (dev.is_cuda()) {
332 cuda_device_index = dev.index();
333 }
334 stop_progress_loop = false;
335 collective_inprogress = false;
336 progress_thread = std::thread(&Comm::progress_loop, this);
337#ifdef _GNU_SOURCE
338 pthread_setname_np(progress_thread.native_handle(), "ucc-progress");
339#endif
340}
341
342Comm::~Comm() {
343 std::unique_lock<std::mutex> lock(mutex);
344 queue_consume_cv.wait(
345 lock, [&] { return progress_queue.empty() && !collective_inprogress; });
346 stop_progress_loop = true;
347 lock.unlock();
348 queue_produce_cv.notify_all();
349 progress_thread.join();
350}
351
352std::shared_ptr<Comm> Comm::get_comm(
353 uint32_t& id,
354 c10::Device dev,
355 std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
356 const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
357 bool is_health_check) {
358 static std::mutex m;
359 static std::weak_ptr<Comm> comm;
360 static uint32_t comm_id;
361
362 std::lock_guard<std::mutex> lock(m);
363 id = comm_id;
364
365 std::string group_id = "group_id";
366 if (is_health_check) {
367 group_id = c10::str(dev.type()) + "/" + group_id;
368 }
369
370 std::vector<uint8_t> remote_comm_id;
371 oob->store->deleteKey(group_id + std::to_string(0));
372 if (oob->rank != 0) {
373 std::vector<uint8_t> val = std::vector<uint8_t>(
374 reinterpret_cast<uint8_t*>(&id),
375 reinterpret_cast<uint8_t*>(&id) + sizeof(id));
376 oob->store->set(group_id + std::to_string(oob->rank), val);
377 } else {
378 for (int i = 1; i < oob->size; i++) {
379 remote_comm_id = oob->store->get(group_id + std::to_string(i));
380 oob->store->deleteKey(group_id + std::to_string(i));
381 // Find the highest id.
382 id = std::max(id, *(reinterpret_cast<uint32_t*>(remote_comm_id.data())));
383 }
384 std::vector<uint8_t> val = std::vector<uint8_t>(
385 reinterpret_cast<uint8_t*>(&id),
386 reinterpret_cast<uint8_t*>(&id) + sizeof(id));
387 oob->store->set(group_id + std::to_string(oob->rank), val);
388 }
389 remote_comm_id = oob->store->get(group_id + std::to_string(0));
390 oob->comm_id = *(reinterpret_cast<uint32_t*>(remote_comm_id.data()));
391 // Prepare comm_id (static variable) to the next id.
392 comm_id = oob->comm_id + 1;
393
394 if (torch_ucc_config.shared_comm) {
395 std::shared_ptr<Comm> shared_comm = comm.lock();
396 if (!shared_comm) {
397 shared_comm = std::make_shared<Comm>(logger, oob, dev, is_health_check);
398 comm = shared_comm;
399 } else {
400 if (dev.is_cuda() && !is_health_check) {
401 if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
402 (shared_comm->cuda_device_index != dev.index())) {
403 TORCH_UCC_LOG_ERROR(
404 is_health_check ? TORCH_UCC_HEALTH_CHECK : TORCH_UCC_INIT,
405 "ucc communicator was initialized with different cuda device,"
406 "multi device is not supported");
407 throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
408 }
409 shared_comm->cuda_device_index = dev.index();
410 }
411 }
412 return shared_comm;
413 } else {
414 return std::make_shared<Comm>(logger, oob, dev, is_health_check);
415 }
416}
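// Worked example of the comm id negotiation above (assumed: 3 ranks whose
// local static comm_id values have diverged to {4, 7, 5}):
//   - ranks 1 and 2 store 7 and 5 under "group_id1" and "group_id2";
//   - rank 0 reads and deletes those keys, takes max(4, 7, 5) = 7, and
//     publishes 7 under "group_id0";
//   - every rank then reads "group_id0", so oob->comm_id == 7 everywhere and
//     the static comm_id advances to 8 for the next process group.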
417
418void Comm::ucc_create_team(
419 ucc_team_h& team,
420 std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
421 ucc_status_t st;
422 ucc_team_params_t team_params;
423 team_params.mask = UCC_TEAM_PARAM_FIELD_EP | UCC_TEAM_PARAM_FIELD_EP_RANGE |
424 UCC_TEAM_PARAM_FIELD_OOB;
425 team_params.oob.allgather = oob_allgather;
426 team_params.oob.req_test = oob_allgather_test;
427 team_params.oob.req_free = oob_allgather_free;
428 team_params.oob.coll_info = oob.get();
429 team_params.oob.n_oob_eps = oob->size;
430 team_params.oob.oob_ep = oob->rank;
431 team_params.ep = oob->rank;
432 team_params.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG;
433 TORCH_UCC_CHECK(
434 ucc_team_create_post(&ucc_comm.context, 1, &team_params, &team),
435 "failed to post team create");
436 do {
437 st = ucc_team_create_test(team);
438 ucc_context_progress(ucc_comm.context);
439 } while (st == UCC_INPROGRESS);
440 TORCH_UCC_CHECK(st, "failed to create UCC team");
441}
442
443void Comm::ucc_destroy_team(ucc_team_h& team) {
444 std::unique_lock<std::mutex> lock(mutex);
445 queue_consume_cv.wait(
446 lock, [&] { return progress_queue.empty() && !collective_inprogress; });
447
  ucc_status_t status;
  do {
    status = ucc_team_destroy(team);
  } while (UCC_INPROGRESS == status);
  if (UCC_OK != status) {
    TORCH_UCC_LOG_ERROR(
        finalize_phase,
        c10::str("ucc team destroy error: ", ucc_status_string(status)));
  }
457
458 lock.unlock();
459}
460
461void Comm::enqueue_collective(
462 std::unique_ptr<ProcessGroupUCC::WorkData> data,
463 c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
464 ucc_coll_args_t& coll,
465 ucc_team_h team) {
466 ucc_coll_req_h request;
467 TORCH_UCC_CHECK(
468 ucc_collective_init(&coll, &request, team), "failed to init collective");
469 TORCH_UCC_CHECK_REQUEST(
470 request, ucc_collective_post(request), "failed to post collective");
471
472 auto entry =
473 std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
474 entry->data = std::move(data);
475 entry->future_ = work->getFuture();
476 work->entry_ = entry;
477 std::unique_lock<std::mutex> lock(mutex);
478 progress_queue.push_back(entry);
479 lock.unlock();
480 queue_produce_cv.notify_one();
481}
482
483#ifdef USE_CUDA
484void Comm::enqueue_cuda_collective(
485 std::unique_ptr<ProcessGroupUCC::WorkData> data,
486 c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
487 ucc_coll_args_t& coll,
488 ucc_team_h team,
489 ucc_ee_h ee) {
490 ucc_coll_req_h request;
491 TORCH_UCC_CHECK(
492 ucc_collective_init(&coll, &request, team),
493 "failed to init cuda collective");
494 ucc_ev_t comp_ev, *post_ev;
495 comp_ev.ev_type = UCC_EVENT_COMPUTE_COMPLETE;
496 comp_ev.ev_context = nullptr;
497 comp_ev.ev_context_size = 0;
498 comp_ev.req = request;
499 TORCH_UCC_CHECK_REQUEST(
500 request,
501 ucc_collective_triggered_post(ee, &comp_ev),
502 "failed to post triggered collective");
503 ucc_status_t st = ucc_ee_get_event(ee, &post_ev);
504 TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
505 ucc_ee_ack_event(ee, post_ev);
506 auto entry =
507 std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
508 entry->data = std::move(data);
509 work->entry_ = entry;
510 std::unique_lock<std::mutex> lock(mutex);
511 progress_queue.push_back(entry);
512 lock.unlock();
513 queue_produce_cv.notify_one();
514}
515#endif
516
517void Comm::progress_loop() {
518 std::unique_lock<std::mutex> lock(mutex);
519#ifdef USE_CUDA
520 bool device_set = false;
521#endif
522 while (!stop_progress_loop) {
523 if (progress_queue.empty()) {
524 queue_produce_cv.wait(lock);
525 continue;
526 }
527 collective_inprogress = true;
528 auto work = progress_queue.front();
529 progress_queue.pop_front();
530 lock.unlock();
531#ifdef USE_CUDA
532 if ((!device_set) && (cuda_device_index != TORCH_UCC_DEVICE_NOT_SET)) {
533 c10::cuda::set_device(cuda_device_index);
534 device_set = true;
535 }
536#endif
537 std::exception_ptr eptr;
538 try {
539 while (work->request_->status > 0) {
540 ucc_comm.progress();
541 }
542 if (work->request_->status < 0) {
543 eptr = std::make_exception_ptr(
544 std::runtime_error(ucc_status_string(work->request_->status)));
545 std::string err_log = c10::str(
546 "Failed to progress communication", // TODO: report exact op type or
547 // id?
548 ucc_status_string(work->request_->status));
549 TORCH_UCC_LOG_ERROR(TORCH_UCC_COLL_PROGRESS, err_log);
550 }
551 } catch (...) {
552 eptr = std::current_exception();
553 }
554 work->finalize(eptr);
555 work = nullptr;
556 collective_inprogress = false;
557 queue_consume_cv.notify_one();
558 lock.lock();
559 }
560}
561
562ProcessGroupUCC::ProcessGroupUCC(
563 const c10::intrusive_ptr<Store>& store,
564 int rank,
565 int size,
566 std::chrono::duration<float> timeout)
567 : Backend(rank, size), timeout_(timeout) {
568 c10::call_once(torch_ucc_config.flag, read_config);
569 oob = std::make_shared<torch_ucc_oob_coll_info_t>();
570 oob->rank = rank;
571 oob->size = size;
572 oob->store = store;
573 comm = nullptr;
574 cuda_ee = nullptr;
575 static uint32_t id = 0;
576 uint32_t pg_id = id++;
577
578 logger = c10::make_intrusive<ProcessGroupUCCLogger>(
579 c10::str("[Rank ", rank_, "]", "[ProcessGroupUCC-", pg_id, "]"),
580 TORCH_UCC_INIT);
581 TORCH_UCC_LOG_INFO(
582 TORCH_UCC_INIT,
583 c10::str(
584 "Created ProcessGroupUCC with ",
585 size,
586 " ranks, with timeout ",
587 timeout_.count(),
588 " secs"));
589 std::string envs = "";
590 for (auto& torch_ucc_env : torch_ucc_envs_map) {
591 envs += ("\n\t" + torch_ucc_env.first + "=" + torch_ucc_env.second);
592 }
593 TORCH_UCC_LOG_INFO(
594 TORCH_UCC_INIT,
595 c10::str(
596 "Successfully read and set ProcessGroupUCC env. variables as followings",
597 envs));
598
599 if (torch_ucc_config.enable_health_check) {
600 // Perform health check by initializing dummy communicators and destroying
601 // them. This will help indicate any UCC/UCX-related issues prior to the
602 // first collective. Run it in a separate thread and wait on CV to handle
603 // timeouts so that if there are hangs, the main thread can still run
604 // correctly.
605 runHealthCheck();
606 }
607 if (torch_ucc_config.enable_comms_logger) {
608 logger->initCommsTracer();
609 }
610}
611
612ProcessGroupUCC::~ProcessGroupUCC() {
613 if (torch_ucc_config.enable_comms_logger) {
614 logger->flushComms(this->getRank(), this->getSize());
615 }
616 if (comm) {
617 logger->setPhase(TORCH_UCC_FINALIZE);
618 comm->ucc_destroy_team(team);
619 TORCH_UCC_LOG_INFO(
620 TORCH_UCC_FINALIZE, "Successfully destroyed UCC library");
621 try {
622 if (cuda_ee) {
623 ucc_ee_destroy(cuda_ee);
624 }
625 } catch (std::exception& ex) {
626 TORCH_UCC_LOG_INFO(
627 TORCH_UCC_FINALIZE,
628 c10::str(
629 "(~ProcessGroupUCC) Caught error in Store Operation .. ",
630 "[",
631 ex.what(),
632 "]"));
633 }
634 comm = nullptr;
635 }
636}
637
638#ifdef USE_CUDA
639// Return CUDA device with ordinal given by input rank.
640c10::Device getCUDADeviceForRank(int rank) {
641 TORCH_CHECK(rank >= 0, "Invalid rank ", rank);
642 auto numGPUs = at::cuda::getNumGPUs();
643 auto deviceIdx = static_cast<c10::DeviceIndex>(rank % numGPUs);
644 return c10::Device(c10::DeviceType::CUDA, deviceIdx);
645}
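// Illustration only: with 8 ranks on a host exposing 4 GPUs, the rank % numGPUs
// mapping above assigns ranks 0..7 to cuda:0, cuda:1, cuda:2, cuda:3, cuda:0,
// cuda:1, cuda:2, cuda:3 respectively.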
646#endif
647
648void ProcessGroupUCC::runHealthCheck() {
649 // Run health check in a separate thread and wait on CV to handle timeouts.
650 // This design allows us to handle hangs.
651
652 // When size_ is 1, there is no need to do any communication at all.
653 if (size_ == 1)
654 return;
655
656 struct HealthCheckData {
657 std::mutex healthCheckMutex;
658 std::condition_variable healthCheckCv;
659 bool uccHealthCheckSuccess = false;
660 std::exception_ptr healthCheckException;
661 } healthCheckData;
662
663 auto t = std::thread([&healthCheckData, this]() {
664 std::list<c10::Device> devices{c10::kCPU};
665#ifdef USE_CUDA
666 c10::cuda::OptionalCUDAGuard gpuGuard;
667 if (at::cuda::is_available()) {
668 devices.emplace_front(getCUDADeviceForRank(rank_));
669 }
670#endif
671 for (auto device : devices) {
672 bool is_last_device = (device == devices.back());
673 try {
674 auto oob = std::make_shared<torch_ucc_oob_coll_info_t>();
675 oob->rank = this->oob->rank;
676 oob->size = this->oob->size;
677 oob->store = this->oob->store;
678 ucc_team_h team = nullptr;
679 uint32_t comm_id;
680#ifdef USE_CUDA
681 if (device.is_cuda()) {
682 gpuGuard.set_index(device.index());
683 }
684#endif
685 auto comm = Comm::get_comm(comm_id, device, oob, logger, true);
686 comm->ucc_create_team(team, oob);
687 comm->ucc_destroy_team(team);
688 TORCH_UCC_LOG_INFO(
689 TORCH_UCC_HEALTH_CHECK,
690 c10::str(
691 "UCC library health check succeed for device ",
692 c10::DeviceTypeName(device.type())));
693 // Mark ucc health check as complete.
694 if (is_last_device) {
695 std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex);
696 healthCheckData.uccHealthCheckSuccess = true;
697 }
698
699 comm = nullptr;
700 oob = nullptr;
701 // Notify main thread the health check is complete.
702 if (is_last_device) {
703 healthCheckData.healthCheckCv.notify_one();
704 }
705 } catch (const std::exception& e) {
706 // Populate exception ptr.
707 healthCheckData.healthCheckException = std::current_exception();
708 // Unblock waiting main thread which will report exception.
709 healthCheckData.healthCheckCv.notify_one();
710 } // Unknown exceptions will just cause the program to terminate.
711 }
712 });
713 // We don't need to join the thread, just need to verify health check via the
714 // CV. Hence we detach the thread here.
715 t.detach(); // NOLINT
716 TORCH_UCC_LOG_INFO(
717 TORCH_UCC_HEALTH_CHECK,
718 c10::str(
719 "will wait up to ",
720 timeout_.count(),
721 " msec for UCC health check to complete."));
722 std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex);
723 healthCheckData.healthCheckCv.wait_for(lock, timeout_, [&healthCheckData]() {
724 return healthCheckData.uccHealthCheckSuccess;
725 });
726
727 if (healthCheckData.healthCheckException) {
728 std::rethrow_exception(healthCheckData.healthCheckException);
729 }
730 // If there is no exception, the likely culprit is a timeout/hang
731 TORCH_CHECK(
732 healthCheckData.uccHealthCheckSuccess,
733 "ProcessGroupUCC: Health check failure: Failed to initialize UCC on rank ",
734 rank_);
735}
736
737void ProcessGroupUCC::set_timeout(ucc_coll_args_t& args) {
738 args.mask |= UCC_COLL_ARGS_FIELD_FLAGS;
739 args.flags |= UCC_COLL_ARGS_FLAG_TIMEOUT;
740 args.timeout = timeout_.count();
741}
742
743#ifdef USE_CUDA
744std::unique_ptr<at::cuda::CUDAEvent> ProcessGroupUCC::getPooledEvent() {
745 std::unique_ptr<at::cuda::CUDAEvent> ev;
746 std::lock_guard<std::mutex> lock(ep.event_pool_mutex);
747 if (ep.event_pool.empty()) {
748 ev = std::make_unique<at::cuda::CUDAEvent>();
749 } else {
750 ev = std::move(ep.event_pool.front());
751 ep.event_pool.pop();
752 }
753 return ev;
754}
755#endif
756
757template <typename PreProcess, typename PostProcess>
758c10::intrusive_ptr<Work> ProcessGroupUCC::collective_post(
759 OpType opType,
760 PreProcess preproc,
761 PostProcess postproc,
762 ucc_coll_args_t& coll,
763 std::unique_ptr<ProcessGroupUCC::WorkData> data,
764 c10::Device dev,
765 std::vector<at::Tensor>& inputTensors,
766 std::vector<at::Tensor>& outputTensors,
767 const char* prof_title) {
768 seq_++;
769 set_timeout(coll);
770 auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(
771 opType, seq_, prof_title, inputTensors, logger);
772
773 if (opType == OpType::RECV) {
774 work->sourceRank_ = coll.root;
775 }
776
777 RECORD_COMMS_TRACE(
778 logger->trace_generator,
779 work,
780 opType,
781 this->getRank(),
782 this->getSize(),
783 inputTensors,
784 outputTensors);
785
786 // Store references to outputs to be used by result
787 work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors);
788 switch (dev.type()) {
789 case c10::DeviceType::CPU: {
790 if (torch_ucc_config.use_future) {
791 work->future_ = c10::make_intrusive<at::ivalue::Future>(
792 c10::ListType::create(c10::TensorType::get()));
793 }
794 preproc();
795 comm->enqueue_collective(std::move(data), work, coll, team);
796 postproc();
797 return work;
798 }
799#ifdef USE_CUDA
800 case c10::DeviceType::CUDA: {
801 auto cuda_ev = getPooledEvent();
802 cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
803 cuda_ev->block(*stream);
804 at::cuda::CUDAStreamGuard guard(*stream);
805 preproc();
806 comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee);
807 postproc();
808 cuda_ev->record(*stream);
809 work->fence = std::move(cuda_ev);
810 work->ep = &ep;
811 if (torch_ucc_config.use_future) {
812 c10::cuda::CUDAMultiStreamGuard streamGuard(*stream);
813 std::vector<c10::Device> devList{dev};
814 work->future_ = c10::make_intrusive<at::ivalue::Future>(
815 c10::ListType::create(c10::TensorType::get()), devList);
816 // Add a callback that runs profiling end callbacks
817 if (work->recordFunctionEndCallback_) {
818 work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
819 work->recordFunctionEndCallback_();
820 });
821 }
822
823 work->future_->markCompleted(c10::IValue(outputTensors));
824 }
825 return work;
826 }
827#endif // #ifdef USE_CUDA
828 default: {
829 TORCH_UCC_LOG_ERROR(
830 TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
831 throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
832 }
833 }
834}
835
836c10::intrusive_ptr<Work> ProcessGroupUCC::allgather(
837 std::vector<std::vector<at::Tensor>>& outputTensors,
838 std::vector<at::Tensor>& inputTensors,
839 const AllgatherOptions& /* unused */) {
840 auto& tensor = inputTensors[0];
841 check_device(tensor.device(), outputTensors[0][0].device());
842 initComm(tensor.device());
843
844 if (tensor.device().is_cpu() || torch_ucc_config.use_allgatherv) {
845 AllgathervWorkData* data = new AllgathervWorkData(size_);
846 for (int i = 0; i < size_; i++) {
847 data->recv_lengths[i] = tensor.element_size() * tensor.numel();
848 data->recv_offsets[i] = (uint64_t)outputTensors[0][i].data_ptr();
849 }
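    // Illustrative values (assumed shapes): with size_ == 2 and a contiguous
    // float32 input of 8 elements per rank, recv_lengths == {32, 32} and each
    // recv_offsets[i] holds the raw address of outputTensors[0][i], so UCC
    // writes rank i's contribution directly into that tensor and no flatten
    // or post-copy step is needed.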
850 ucc_coll_args_t coll;
851 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
852 coll.flags =
853 UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
854 coll.coll_type = UCC_COLL_TYPE_ALLGATHERV;
855 coll.src.info.buffer = tensor.data_ptr();
856 coll.src.info.count = tensor.element_size() * tensor.numel();
857 coll.src.info.datatype = UCC_DT_UINT8;
858 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
859 coll.dst.info_v.buffer = nullptr;
860 coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
861 coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
862 coll.dst.info_v.datatype = UCC_DT_UINT8;
863 coll.dst.info_v.mem_type =
864 to_ucc_memType(outputTensors[0][0].device().type());
865 SAVE_TENSORS(inputTensors, data->src);
866 SAVE_TENSORS(outputTensors[0], data->dst);
867
868 return collective_post(
869 OpType::ALLGATHER,
870 []() {},
871 []() {},
872 coll,
873 std::unique_ptr<WorkData>(data),
874 tensor.device(),
875 inputTensors,
876 outputTensors[0],
877 "ucc:all_gather");
878 } else {
879 WorkData* data = new WorkData();
880 std::vector<at::Tensor> flat_output(outputTensors.size());
881 for (size_t i = 0; i < outputTensors.size(); i++) {
882 TORCH_CHECK(
883 outputTensors[i].size() == outputTensors.size() * size_,
884 "Tensor output list is not valid for the number of participants");
885 flat_output[i] = c10d::newLikeFlat(outputTensors, i);
886 }
887 SAVE_TENSORS(flat_output, data->flat);
888 ucc_coll_args_t coll;
889 coll.mask = 0;
890 coll.flags = 0;
891 coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
892 coll.src.info.buffer = tensor.data_ptr();
893 coll.src.info.count = tensor.numel();
894 coll.src.info.datatype = to_ucc_dType(tensor);
895 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
896 coll.dst.info.buffer = flat_output[0].data_ptr();
897 coll.dst.info.count = flat_output[0].numel();
898 coll.dst.info.datatype = to_ucc_dType(flat_output[0]);
899 coll.dst.info.mem_type =
900 to_ucc_memType(outputTensors[0][0].device().type());
901
902 auto copy_from_flat = [&] {
903 bool asyncCopy = false;
904#ifdef USE_CUDA
      bool isCuda = outputTensors[0][0].device().is_cuda();
907#endif
908 for (size_t i = 0; i < outputTensors.size(); i++) {
909 auto inumel = inputTensors[i].numel();
910 for (size_t j = 0; j < outputTensors[i].size(); j++) {
911 TORCH_CHECK(
912 (outputTensors[i][j].numel() == inumel),
913 "Tensor operand counts must be same");
914#ifdef USE_CUDA
915 if (isCuda) {
916 c10::cuda::CUDACachingAllocator::recordStream(
917 outputTensors[i][j].storage().data_ptr(), (*stream));
918 asyncCopy = true;
919 }
920#endif
921 outputTensors[i][j].copy_(flat_output[i][j], asyncCopy);
922 }
923 }
924 };
925 return collective_post(
926 OpType::ALLGATHER,
927 []() {},
928 copy_from_flat,
929 coll,
930 std::unique_ptr<WorkData>(data),
931 tensor.device(),
932 inputTensors,
933 outputTensors[0],
934 "ucc:all_gather");
935 }
936}
937
938c10::intrusive_ptr<Work> ProcessGroupUCC::_allgather_base(
939 at::Tensor& outputTensor,
940 at::Tensor& inputTensor,
941 const AllgatherOptions& opts) {
942 check_tensor({outputTensor});
943 check_tensor({inputTensor});
944 initComm(outputTensor.device());
945
946 WorkData* data = new WorkData();
947
948 ucc_coll_args_t coll;
949 coll.mask = 0;
950 coll.flags = 0;
951 coll.coll_type = UCC_COLL_TYPE_ALLGATHER;
952 coll.src.info.buffer = inputTensor.data_ptr();
953 coll.src.info.count = inputTensor.numel();
954 coll.src.info.datatype = ucc_dtype_map.at(inputTensor.scalar_type());
955 coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
956 coll.dst.info.buffer = outputTensor.data_ptr();
957 coll.dst.info.count = outputTensor.numel();
958 coll.dst.info.datatype = ucc_dtype_map.at(outputTensor.scalar_type());
959 coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
960
961 std::vector<at::Tensor> inputTensors = {inputTensor};
962 std::vector<at::Tensor> outputTensors = {outputTensor};
963 SAVE_TENSORS(inputTensors, data->src);
964 SAVE_TENSORS(outputTensors, data->dst);
965
966 return collective_post(
967 OpType::_ALLGATHER_BASE,
968 []() {},
969 []() {},
970 coll,
971 std::unique_ptr<WorkData>(data),
972 outputTensor.device(),
973 inputTensors,
974 outputTensors,
975 "ucc:allgather_base");
976}
977
978c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce(
979 std::vector<at::Tensor>& tensors,
980 const AllreduceOptions& opts) {
981 check_tensor(tensors);
982 auto& tensor = tensors[0];
983 initComm(tensor.device());
984 WorkData* data = new WorkData();
985
986 ucc_coll_args_t coll;
987 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
988 coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
989 coll.coll_type = UCC_COLL_TYPE_ALLREDUCE;
990 coll.op = to_ucc_reduceOp(opts.reduceOp, tensor.scalar_type());
991 coll.src.info.buffer = nullptr;
992 coll.src.info.count = tensor.numel();
993 coll.src.info.datatype = to_ucc_dType(tensor);
994 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
995 coll.dst.info.buffer = tensor.data_ptr();
996 coll.dst.info.count = tensor.numel();
997 coll.dst.info.datatype = to_ucc_dType(tensor);
998 coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
999 SAVE_TENSORS(tensors, data->dst);
1000 return collective_post(
1001 OpType::ALLREDUCE,
1002 []() {},
1003 []() {},
1004 coll,
1005 std::unique_ptr<WorkData>(data),
1006 tensor.device(),
1007 tensors,
1008 tensors,
1009 "ucc:all_reduce");
1010}
1011
1012c10::intrusive_ptr<Work> ProcessGroupUCC::allreduce_coalesced(
1013 std::vector<at::Tensor>& /* unused */,
1014 const AllreduceCoalescedOptions& /* unused */) {
1015 throw std::runtime_error(
1016 "ProcessGroupUCC does not support allreduce_coalesced");
1017}
1018
1019c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall(
1020 std::vector<at::Tensor>& outputTensors,
1021 std::vector<at::Tensor>& inputTensors,
1022 const AllToAllOptions& /* unused */) {
1023 auto device = outputTensors[0].device();
1024 for (const auto r : c10::irange(outputTensors.size())) {
1025 TORCH_CHECK(
1026 device == outputTensors[r].device() &&
1027 device == inputTensors[r].device(),
1028 "Tensors must be on the same device")
1029 }
1030
1031 initComm(device);
1032 ucc_coll_args_t coll;
1033 AlltoallWorkData* data;
1034 data = new AlltoallWorkData(size_);
1035
  /* To avoid flattening the tensors, we use alltoallv to implement Alltoall as
     follows:
      1. store the address of each tensor directly in displacements and keep
         buffer as nullptr, i.e., 0
      2. convert the datatype to UINT8, which is always 1 byte, to avoid wrong
         size calculation in the UCC layer
      3. post the Alltoallv
  */
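  // Illustrative values (assumed shapes): with size_ == 2 and per-peer float32
  // tensors of 4 elements, the loop below yields send_lengths == {16, 16} and
  // send_offsets == {(uint64_t)inputTensors[0].data_ptr(),
  //                  (uint64_t)inputTensors[1].data_ptr()},
  // i.e. lengths are byte counts and "displacements" are absolute addresses,
  // which is why the buffers stay nullptr and the datatype is UCC_DT_UINT8.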
1044 for (const auto i : c10::irange(size_)) {
1045 data->send_lengths[i] =
1046 (uint64_t)(inputTensors[i].element_size() * inputTensors[i].numel());
1047 data->send_offsets[i] = (uint64_t)inputTensors[i].data_ptr();
1048 data->recv_lengths[i] =
1049 (uint64_t)(outputTensors[i].element_size() * outputTensors[i].numel());
1050 data->recv_offsets[i] = (uint64_t)outputTensors[i].data_ptr();
1051 }
1052
1053 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1054 coll.flags =
1055 UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1056 coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
1057 coll.src.info_v.buffer = 0;
1058 coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1059 coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1060 coll.src.info_v.datatype = UCC_DT_UINT8;
1061 coll.src.info_v.mem_type = to_ucc_memType(inputTensors[0].device().type());
1062 coll.dst.info_v.buffer = 0;
1063 coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1064 coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1065 coll.dst.info_v.datatype = UCC_DT_UINT8;
1066 coll.dst.info_v.mem_type = to_ucc_memType(outputTensors[0].device().type());
1067
1068 SAVE_TENSORS(inputTensors, data->src);
1069 SAVE_TENSORS(outputTensors, data->dst);
1070
1071 return collective_post(
1072 OpType::ALLTOALL,
1073 []() {},
1074 []() {},
1075 coll,
1076 std::unique_ptr<WorkData>(data),
1077 device,
1078 inputTensors,
1079 outputTensors,
1080 "ucc:alltoall");
1081}
1082
1083c10::intrusive_ptr<Work> ProcessGroupUCC::alltoall_base(
1084 at::Tensor& outputTensor,
1085 at::Tensor& inputTensor,
1086 std::vector<int64_t>& outputSplitSizes,
1087 std::vector<int64_t>& inputSplitSizes,
1088 const AllToAllOptions& /* unused */) {
1089 check_device(inputTensor.device(), outputTensor.device());
1090 initComm(inputTensor.device());
1091 ucc_coll_args_t coll;
1092 AlltoallWorkData* data;
1093
1094 if ((outputSplitSizes.size() == 0) && (inputSplitSizes.size() == 0)) {
1095 data = new AlltoallWorkData(0);
1096 TORCH_CHECK(
1097 (outputTensor.size(0) % size_ == 0) &&
1098 (inputTensor.size(0) % size_ == 0),
1099 "Tensor's dim 0 does not divide equally across group size");
1100 coll.mask = 0;
1101 coll.flags = 0;
1102 coll.coll_type = UCC_COLL_TYPE_ALLTOALL;
1103 coll.src.info.buffer = inputTensor.data_ptr();
1104 coll.src.info.count = inputTensor.element_size() * inputTensor.numel();
1105 coll.src.info.datatype = UCC_DT_UINT8;
1106 coll.src.info.mem_type = to_ucc_memType(inputTensor.device().type());
1107 coll.dst.info.buffer = outputTensor.data_ptr();
1108 coll.dst.info.count = outputTensor.element_size() * outputTensor.numel();
1109 coll.dst.info.datatype = UCC_DT_UINT8;
1110 coll.dst.info.mem_type = to_ucc_memType(outputTensor.device().type());
1111 coll.flags = 0;
1112 } else {
1113 data = new AlltoallWorkData(size_);
1114 c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
1115 c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
1116 computeLengthsAndOffsets(
1117 outputSplitSizes,
1118 outputTensor,
1119 &data->recv_lengths,
1120 &data->recv_offsets);
1121 computeLengthsAndOffsets(
1122 inputSplitSizes, inputTensor, &data->send_lengths, &data->send_offsets);
1123 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1124 coll.coll_type = UCC_COLL_TYPE_ALLTOALLV;
1125 coll.src.info_v.buffer = inputTensor.data_ptr();
1126 coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1127 coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1128 coll.src.info_v.datatype = to_ucc_dType(inputTensor);
1129 coll.src.info_v.mem_type = to_ucc_memType(inputTensor.device().type());
1130 coll.dst.info_v.buffer = outputTensor.data_ptr();
1131 coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1132 coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1133 coll.dst.info_v.datatype = to_ucc_dType(outputTensor);
1134 coll.dst.info_v.mem_type = to_ucc_memType(outputTensor.device().type());
1135 coll.flags = UCC_COLL_ARGS_FLAG_CONTIG_SRC_BUFFER |
1136 UCC_COLL_ARGS_FLAG_CONTIG_DST_BUFFER | UCC_COLL_ARGS_FLAG_COUNT_64BIT |
1137 UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1138
1139 if (torch_ucc_config.enable_comms_logger) {
1140 logger->trace_generator->recordOptionalInfo(
1141 outputSplitSizes, inputSplitSizes);
1142 }
1143 }
1144 std::vector<at::Tensor> inputTensors = {inputTensor};
1145 std::vector<at::Tensor> outputTensors = {outputTensor};
1146 SAVE_TENSORS(inputTensors, data->src);
1147 SAVE_TENSORS(outputTensors, data->dst);
1148
1149 return collective_post(
1150 OpType::ALLTOALL_BASE,
1151 []() {},
1152 []() {},
1153 coll,
1154 std::unique_ptr<WorkData>(data),
1155 inputTensor.device(),
1156 inputTensors,
1157 outputTensors,
1158 "ucc:alltoall");
1159}
1160
1161c10::intrusive_ptr<Work> ProcessGroupUCC::barrier(const BarrierOptions& opts) {
1162 c10::Device device = c10::Device(c10::DeviceType::CPU);
1163#ifdef USE_CUDA
1164 auto numGPUs = c10::cuda::device_count();
1165 if (!opts.device_ids.empty()) {
1166 device = c10::Device(c10::DeviceType::CUDA, opts.device_ids.front());
1167 } else if (comm && comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) {
1168 device = c10::Device(c10::DeviceType::CUDA, comm->cuda_device_index);
1169 } else if (numGPUs > 0) {
1170 int8_t deviceIdx = static_cast<int8_t>(c10::cuda::current_device());
1171 // if current device is 0, likely the device is not set, use the best guess
1172 if (0 == (int)deviceIdx) {
1173 deviceIdx = static_cast<int8_t>(this->getRank() % numGPUs);
1174 }
1175 TORCH_UCC_LOG_INFO(
1176 TORCH_UCC_COLL_POST,
1177 c10::str(
1178 "post barrier before specifying any GPU while there are ",
1179 numGPUs,
1180 " GPUs available. ",
1181 "Not clear if GPU barrier is required, using GPU ",
1182 (int)deviceIdx,
1183 " to perform barrier. ",
1184 "Specify device_ids option in barrier() to force ",
1185 "use of a particular device"));
1186 device = c10::Device(c10::DeviceType::CUDA, deviceIdx);
1187 }
1188#endif
1189 initComm(device);
1190
1191 ucc_coll_args_t coll;
1192 coll.mask = 0;
1193 coll.flags = 0;
1194 coll.coll_type = UCC_COLL_TYPE_BARRIER;
1195 auto dummy_tensor = std::vector<at::Tensor>();
1196 return collective_post(
1197 OpType::BARRIER,
1198 []() {},
1199 []() {},
1200 coll,
1201 nullptr,
1202 device,
1203 dummy_tensor,
1204 dummy_tensor,
1205 "ucc:barrier");
1206}
1207
1208c10::intrusive_ptr<Work> ProcessGroupUCC::broadcast(
1209 std::vector<at::Tensor>& tensors,
1210 const BroadcastOptions& opts) {
1211 check_tensor(tensors);
1212 auto& tensor = tensors[0];
1213 initComm(tensor.device());
1214 WorkData* data = new WorkData();
1215
1216 ucc_coll_args_t coll;
1217 coll.mask = 0;
1218 coll.flags = 0;
1219 coll.coll_type = UCC_COLL_TYPE_BCAST;
1220 coll.src.info.buffer = tensor.data_ptr();
1221 coll.src.info.count = tensor.numel();
1222 coll.src.info.datatype = to_ucc_dType(tensor);
1223 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1224 coll.root = opts.rootRank;
1225 SAVE_TENSORS(tensors, data->dst);
1226
1227 if (torch_ucc_config.enable_comms_logger) {
1228 logger->trace_generator->recordOptionalInfo(opts.rootRank);
1229 }
1230
1231 return collective_post(
1232 OpType::BROADCAST,
1233 []() {},
1234 []() {},
1235 coll,
1236 std::unique_ptr<WorkData>(data),
1237 tensor.device(),
1238 tensors,
1239 tensors,
1240 "ucc:broadcast");
1241}
1242
1243c10::intrusive_ptr<Work> ProcessGroupUCC::gather(
1244 std::vector<std::vector<at::Tensor>>& outputTensors,
1245 std::vector<at::Tensor>& inputTensors,
1246 const GatherOptions& opts) {
1247 std::vector<at::Tensor> outputs;
1248 auto& input = inputTensors[0];
1249 initComm(input.device());
1250
1251 AllgathervWorkData* data = new AllgathervWorkData(size_);
1252 ucc_coll_args_t coll;
1253 coll.root = opts.rootRank;
1254 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1255 coll.flags =
1256 UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1257 coll.coll_type = UCC_COLL_TYPE_GATHERV;
1258
1259 /* for non-root ranks, only src is valid */
1260 coll.src.info.buffer = input.data_ptr();
1261 coll.src.info.count = (uint64_t)(input.element_size() * input.numel());
1262 coll.src.info.datatype = UCC_DT_UINT8;
1263 coll.src.info.mem_type = to_ucc_memType(input.device().type());
1264
1265 if (getRank() == opts.rootRank) {
1266 if (outputTensors.size() != 1) {
1267 TORCH_UCC_LOG_ERROR(
1268 TORCH_UCC_COLL_POST,
1269 c10::str(
1270 "gather requires a single-element output list containing a list with ",
1271 getSize(),
1272 " tensors."));
1273 } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
1274 TORCH_UCC_LOG_ERROR(
1275 TORCH_UCC_COLL_POST,
1276 c10::str(
1277 "Incorrect output list size ",
1278 outputTensors[0].size(),
1279 ". Output list size should be ",
1280 getSize(),
1281 ", same as size of the process group."));
1282 }
1283 outputs = outputTensors[0];
1284
1285 for (int i = 0; i < size_; i++) {
1286 data->recv_lengths[i] =
1287 (uint64_t)(outputs[i].element_size() * outputs[i].numel());
1288 data->recv_offsets[i] = (uint64_t)outputs[i].data_ptr();
1289 }
    /* use gatherv and store non-contiguous addresses in displacements to avoid
     * flattening outputTensors */
1292 coll.dst.info_v.buffer = nullptr;
1293 coll.dst.info_v.counts = (ucc_count_t*)data->recv_lengths.data();
1294 coll.dst.info_v.displacements = (ucc_aint_t*)data->recv_offsets.data();
1295 coll.dst.info_v.datatype = UCC_DT_UINT8;
1296 coll.dst.info_v.mem_type = to_ucc_memType(outputs[0].device().type());
1297
1298 SAVE_TENSORS(outputs, data->dst);
1299 } else {
1300 // for non-root ranks, outputTensors should be an empty list
1301 if (outputTensors.size() != 0) {
1302 TORCH_UCC_LOG_ERROR(
1303 TORCH_UCC_COLL_POST, "requires empty output on non-root");
1304 }
1305 outputs = {};
    // append an empty tensor to the list to be used by the future mark
1307 outputs.emplace_back();
1308 }
1309
1310 SAVE_TENSORS(inputTensors, data->src);
1311
1312 return collective_post(
1313 OpType::GATHER,
1314 []() {},
1315 []() {},
1316 coll,
1317 std::unique_ptr<WorkData>(data),
1318 input.device(),
1319 inputTensors,
1320 outputs,
1321 "ucc:gather");
1322}
1323
1324c10::intrusive_ptr<Work> ProcessGroupUCC::reduce(
1325 std::vector<at::Tensor>& tensors,
1326 const ReduceOptions& opts) {
1327 check_tensor(tensors);
1328 auto& tensor = tensors[0];
1329 initComm(tensor.device());
1330 WorkData* data = new WorkData();
1331
1332 ucc_coll_args_t coll;
1333 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1334 coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
1335 coll.coll_type = UCC_COLL_TYPE_REDUCE;
1336 coll.op = ucc_op_map.at(opts.reduceOp);
1337 coll.root = opts.rootRank;
1338 coll.src.info.buffer = tensor.data_ptr();
1339 coll.src.info.count = tensor.numel();
1340 coll.src.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
1341 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1342 coll.dst.info.buffer = tensor.data_ptr();
1343 coll.dst.info.count = tensor.numel();
1344 coll.dst.info.datatype = ucc_dtype_map.at(tensor.scalar_type());
1345 coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
1346 SAVE_TENSORS(tensors, data->dst);
1347 return collective_post(
1348 OpType::REDUCE,
1349 []() {},
1350 []() {},
1351 coll,
1352 std::unique_ptr<WorkData>(data),
1353 tensor.device(),
1354 tensors,
1355 tensors,
1356 "ucc:reduce");
1357}
1358
1359c10::intrusive_ptr<Work> ProcessGroupUCC::reduce_scatter(
1360 std::vector<at::Tensor>& outputTensors,
1361 std::vector<std::vector<at::Tensor>>& inputTensors,
1362 const ReduceScatterOptions& opts) {
1363 TORCH_CHECK(
1364 (outputTensors.size() == inputTensors.size()),
1365 "Tensor input/output list for reduce_scatter must have same size");
1366 check_tensor(outputTensors);
1367 check_device(inputTensors[0][0].device(), outputTensors[0].device());
1368 initComm(inputTensors[0][0].device());
1369 auto data = std::make_unique<WorkData>();
1370 std::vector<at::Tensor> flat_input(inputTensors.size());
1371 for (size_t i = 0; i < inputTensors.size(); i++) {
1372 TORCH_CHECK(
1373 inputTensors[i].size() == inputTensors.size() * size_,
1374 "Tensor input list is not valid for the number of participants");
1375 flat_input[i] = c10d::newLikeFlat(inputTensors, i);
1376 }
1377 SAVE_TENSORS(flat_input, data->flat);
1378 check_tensor(flat_input);
1379 ucc_coll_args_t coll;
1380 coll.mask = 0;
1381 coll.flags = 0;
1382 coll.coll_type = UCC_COLL_TYPE_REDUCE_SCATTER;
1383 coll.op = to_ucc_reduceOp(opts.reduceOp, flat_input[0].scalar_type());
1384
1385 coll.src.info.buffer = flat_input[0].data_ptr();
1386 coll.src.info.count = flat_input[0].numel();
1387 coll.src.info.datatype = to_ucc_dType(flat_input[0]);
1388 coll.src.info.mem_type = to_ucc_memType(flat_input[0].device().type());
1389 coll.dst.info.buffer = outputTensors[0].data_ptr();
1390 coll.dst.info.count = outputTensors[0].numel();
1391 coll.dst.info.datatype = to_ucc_dType(outputTensors[0]);
1392 coll.dst.info.mem_type = to_ucc_memType(outputTensors[0].device().type());
1393
1394 SAVE_TENSORS(inputTensors[0], data->src);
1395 SAVE_TENSORS(outputTensors, data->dst);
1396
1397 auto copy_to_flat = [&] {
1398 bool asyncCopy = false;
1399 auto isize = inputTensors.size();
1400#ifdef USE_CUDA
1401 bool isCuda = inputTensors[0][0].device().is_cuda();
1402#endif
1403 for (size_t i = 0; i < isize; i++) {
1404 auto onumel = outputTensors[i].numel();
1405 for (size_t j = 0; j < inputTensors[i].size(); j++) {
1406 TORCH_CHECK(
1407 (inputTensors[i][j].numel() == onumel),
1408 "Tensor operand counts must be same");
1409#ifdef USE_CUDA
1410 if (isCuda) {
1411 c10::cuda::CUDACachingAllocator::recordStream(
1412 inputTensors[i][j].storage().data_ptr(), (*stream));
1413 asyncCopy = true;
1414 }
1415#endif
1416 flat_input[i][j].copy_(inputTensors[i][j], asyncCopy);
1417 }
1418 }
1419 };
1420
1421 return collective_post(
1422 OpType::REDUCE_SCATTER,
1423 copy_to_flat,
1424 []() {},
1425 coll,
1426 std::move(data),
1427 inputTensors[0][0].device(),
1428 inputTensors[0],
1429 outputTensors,
1430 "ucc:reduce_scatter");
1431}
1432
1433c10::intrusive_ptr<Work> ProcessGroupUCC::scatter(
1434 std::vector<at::Tensor>& outputTensors,
1435 std::vector<std::vector<at::Tensor>>& inputTensors,
1436 const ScatterOptions& opts) {
1437 auto& tensor = outputTensors[0];
1438 initComm(tensor.device());
1439
1440 ScattervWorkData* data = new ScattervWorkData(size_);
1441 ucc_coll_args_t coll;
1442 coll.root = opts.rootRank;
1443 coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
1444 coll.flags =
1445 UCC_COLL_ARGS_FLAG_COUNT_64BIT | UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT;
1446 coll.coll_type = UCC_COLL_TYPE_SCATTERV;
1447
1448 if (getRank() == opts.rootRank) {
    /* src is only valid at the root rank */
1450 if (inputTensors.size() != 1) {
1451 TORCH_UCC_LOG_ERROR(
1452 TORCH_UCC_COLL_POST,
1453 c10::str(
1454 "gather requires a single-element output list containing a list with ",
1455 getSize(),
1456 " tensors."));
1457 } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
1458 TORCH_UCC_LOG_ERROR(
1459 TORCH_UCC_COLL_POST,
1460 c10::str(
1461 "Incorrect output list size ",
1462 inputTensors[0].size(),
1463 ". Output list size should be ",
1464 getSize(),
1465 ", same as size of the process group."));
1466 }
1467
1468 for (int i = 0; i < size_; i++) {
1469 data->send_lengths[i] = (uint64_t)tensor.element_size() * tensor.numel();
1470 data->send_offsets[i] = (uint64_t)inputTensors[0][i].data_ptr();
1471 }
    /* use scatterv and store non-contiguous addresses in displacements to
     * avoid flattening inputTensors */
1474 coll.src.info_v.buffer = nullptr;
1475 coll.src.info_v.counts = (ucc_count_t*)data->send_lengths.data();
1476 coll.src.info_v.displacements = (ucc_aint_t*)data->send_offsets.data();
1477 coll.src.info_v.datatype = UCC_DT_UINT8;
1478 coll.src.info_v.mem_type =
1479 to_ucc_memType(inputTensors[0][0].device().type());
1480
1481 SAVE_TENSORS(inputTensors[0], data->src);
1482 } else {
1483 // for non-root ranks, inputTensors should be an empty list
1484 if (inputTensors.size() != 0) {
1485 TORCH_UCC_LOG_ERROR(
          TORCH_UCC_COLL_POST, "requires empty input on non-root");
1487 }
1488 }
1489
1490 coll.dst.info.buffer = tensor.data_ptr();
1491 coll.dst.info.count = (uint64_t)tensor.element_size() * tensor.numel();
1492 coll.dst.info.datatype = UCC_DT_UINT8;
1493 coll.dst.info.mem_type = to_ucc_memType(tensor.device().type());
1494 SAVE_TENSORS(outputTensors, data->dst);
1495
1496 return collective_post(
1497 OpType::SCATTER,
1498 []() {},
1499 []() {},
1500 coll,
1501 std::unique_ptr<WorkData>(data),
1502 tensor.device(),
1503 inputTensors[0],
1504 outputTensors,
1505 "ucc:scatter");
1506}
1507
1508c10::intrusive_ptr<Work> ProcessGroupUCC::send(
1509 std::vector<at::Tensor>& tensors,
1510 int dstRank,
1511 int tag) {
1512 check_tensor(tensors);
1513 auto& tensor = tensors[0];
1514 initComm(tensor.device());
1515
1516 WorkData* data = new WorkData();
1517 ucc_coll_args_t coll;
1518 coll.tag = tag;
1519 coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
1520 coll.flags = 0;
1521 coll.coll_type = UCC_COLL_TYPE_BCAST;
1522 coll.src.info.buffer = tensor.data_ptr();
1523 coll.src.info.count = tensor.numel();
1524 coll.src.info.datatype = to_ucc_dType(tensor);
1525 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1526 coll.root = getRank();
1527
1528 coll.active_set.size = 2;
1529 coll.active_set.start = getRank();
1530 coll.active_set.stride = dstRank - getRank();
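  // The active set above restricts this broadcast to the two endpoints of the
  // point-to-point operation. Worked example (assumed ranks): with
  // getRank() == 1 and dstRank == 3, start = 1, stride = 2 and size = 2, so
  // the participants are {1, 3} with this rank as the broadcast root (sender).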
1531 SAVE_TENSORS(tensors, data->dst);
1532
1533 return collective_post(
1534 OpType::SEND,
1535 []() {},
1536 []() {},
1537 coll,
1538 std::unique_ptr<WorkData>(data),
1539 tensor.device(),
1540 tensors,
1541 tensors,
1542 "ucc:send");
1543}
1544
1545c10::intrusive_ptr<Work> ProcessGroupUCC::recv(
1546 std::vector<at::Tensor>& tensors,
1547 int srcRank,
1548 int tag) {
1549 check_tensor(tensors);
1550 auto& tensor = tensors[0];
1551 initComm(tensor.device());
1552
1553 WorkData* data = new WorkData();
1554 ucc_coll_args_t coll;
1555 coll.tag = tag;
1556 coll.mask = UCC_COLL_ARGS_FIELD_ACTIVE_SET | UCC_COLL_ARGS_FIELD_TAG;
1557 coll.flags = 0;
1558 coll.coll_type = UCC_COLL_TYPE_BCAST;
1559 coll.src.info.buffer = tensor.data_ptr();
1560 coll.src.info.count = tensor.numel();
1561 coll.src.info.datatype = to_ucc_dType(tensor);
1562 coll.src.info.mem_type = to_ucc_memType(tensor.device().type());
1563 coll.root = srcRank;
1564
1565 coll.active_set.size = 2;
1566 coll.active_set.start = srcRank;
1567 coll.active_set.stride = getRank() - srcRank;
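  // Mirror of the send() case above: with getRank() == 3 and srcRank == 1,
  // start = 1 and stride = 2, so the active set is again {1, 3}, with srcRank
  // as the broadcast root and this rank as the receiver.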
1568 SAVE_TENSORS(tensors, data->dst);
1569
1570 return collective_post(
1571 OpType::RECV,
1572 []() {},
1573 []() {},
1574 coll,
1575 std::unique_ptr<WorkData>(data),
1576 tensor.device(),
1577 tensors,
1578 tensors,
1579 "ucc:recv");
1580}
1581
1582void ProcessGroupUCC::setSequenceNumberForGroup() {}
1583
1584uint64_t ProcessGroupUCC::getSequenceNumberForGroup() {
1585 return seq_;
1586}
1587
1588c10::intrusive_ptr<Backend> ProcessGroupUCC::createProcessGroupUCC(
1589 const c10::intrusive_ptr<::c10d::Store>& store,
1590 int rank,
1591 int size,
1592 const std::chrono::duration<float>& timeout) {
1593 return c10::make_intrusive<ProcessGroupUCC>(store, rank, size, timeout);
1594}
1595
1596void ProcessGroupUCC::initComm(c10::Device dev) {
1597 if (!comm) {
1598#ifdef USE_CUDA
1599 if (dev.is_cuda()) {
1600 c10::cuda::set_device(dev.index());
1601 }
1602#endif
1603 comm = Comm::get_comm(comm_id, dev, oob, logger);
1604 TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
1605 comm->ucc_create_team(team, oob);
1606 TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
1607 logger->setPhase(TORCH_UCC_READY);
1608 } else {
1609 if (dev.is_cuda()) {
1610 if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
1611 (comm->cuda_device_index != dev.index())) {
1612 TORCH_UCC_LOG_ERROR(
1613 TORCH_UCC_INIT,
1614 "ucc communicator was initialized with different cuda device,"
1615 "multi device is not supported");
1616 throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
1617 }
1618 comm->cuda_device_index = dev.index();
1619 }
1620 }
1621#ifdef USE_CUDA
1622 // Create UCC execution engine.
1623 if (!cuda_ee && dev.is_cuda()) {
1624 stream = std::make_unique<at::cuda::CUDAStream>(
1625 at::cuda::getStreamFromPool(true, dev.index()));
1626 ucc_ee_params_t params;
1627 params.ee_type = UCC_EE_CUDA_STREAM;
1628 params.ee_context = (void*)stream->stream();
1629 params.ee_context_size = sizeof(cudaStream_t);
1630 TORCH_UCC_CHECK(
1631 ucc_ee_create(team, &params, &cuda_ee),
1632 "failed to create UCC execution engine");
1633 }
1634#endif
1635}
1636
1637} // namespace c10d
1638
1639#endif // USE_C10D_UCC
1640