#include <ATen/core/functional.h>
#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/cuda/nccl.h>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>

#include <nccl.h>

#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_map>

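// Helpers that convert between the opaque torch::cuda::nccl wrapper types and
// the real NCCL types (the wrappers let the public header avoid including
// nccl.h directly). These reinterpret_casts are the only place the two sets
// of types meet.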
ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
  return reinterpret_cast<ncclComm_t*>(var);
}

ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
  return reinterpret_cast<ncclComm_t>(var);
}

ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
  return reinterpret_cast<ncclUniqueId*>(var);
}

ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
  switch (var) {
    case torch::cuda::nccl::ncclResult::Success:
      return ncclResult_t::ncclSuccess;
    case torch::cuda::nccl::ncclResult::UnhandledCudaError:
      return ncclResult_t::ncclUnhandledCudaError;
    case torch::cuda::nccl::ncclResult::SystemError:
      return ncclResult_t::ncclSystemError;
    case torch::cuda::nccl::ncclResult::InternalError:
      return ncclResult_t::ncclInternalError;
    case torch::cuda::nccl::ncclResult::InvalidArgument:
      return ncclResult_t::ncclInvalidArgument;
    case torch::cuda::nccl::ncclResult::InvalidUsage:
      return ncclResult_t::ncclInvalidUsage;
    case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
  switch (var) {
    case ncclSuccess:
      return torch::cuda::nccl::ncclResult::Success;
    case ncclUnhandledCudaError:
      return torch::cuda::nccl::ncclResult::UnhandledCudaError;
    case ncclSystemError:
      return torch::cuda::nccl::ncclResult::SystemError;
    case ncclInternalError:
      return torch::cuda::nccl::ncclResult::InternalError;
    case ncclInvalidArgument:
      return torch::cuda::nccl::ncclResult::InvalidArgument;
    case ncclInvalidUsage:
      return torch::cuda::nccl::ncclResult::InvalidUsage;
    case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
  switch (type) {
    case at::kFloat:
      return ncclDataType_t::ncclFloat;
    case at::kHalf:
      return ncclDataType_t::ncclHalf;
    case at::kDouble:
      return ncclDataType_t::ncclDouble;
    case at::kLong:
      return ncclDataType_t::ncclInt64;
    case at::kInt:
      return ncclDataType_t::ncclInt;
    case at::kChar:
      return ncclDataType_t::ncclChar;
    case at::kByte:
      return ncclDataType_t::ncclUint8;
    case at::kBool:
      return ncclDataType_t::ncclUint8;
#if HAS_NCCL_BF16_DATATYPE
    case at::kBFloat16:
      return ncclDataType_t::ncclBfloat16;
#endif
    default:
      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
  }
}

ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
  if (!t.is_cuda()) {
    TORCH_CHECK(
        false,
        "NCCL only supports CUDA tensors, but got a tensor on ",
        t.device());
  }
  return to_nccl_data_type(t.scalar_type());
}

ncclRedOp_t to_nccl_red_op(int var) {
  return static_cast<ncclRedOp_t>(var);
}

namespace torch {
namespace cuda {
namespace nccl {

using namespace at;

namespace detail {

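// Overload for raw NCCL return codes: convert to the torch-side ncclResult
// enum and forward to the NCCL_CHECK overload that takes it, which throws on
// failure. The inner call is not recursive; from_nccl_result returns a
// different type, so overload resolution picks the other overload.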
static inline void NCCL_CHECK(ncclResult_t result) {
  NCCL_CHECK(from_nccl_result(result));
}

void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
  std::ostringstream err;
  err << "NCCL Error " << static_cast<int>(status) << ": "
      << ncclGetErrorString(to_nccl_result(status));
  throw std::runtime_error(err.str());
}

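// RAII wrapper around a clique of NCCL communicators, one per device in the
// list passed to the constructor. The communicators are created eagerly with
// ncclCommInitAll and torn down (best effort) in the destructor.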
struct NcclCommList {
  std::unique_ptr<ncclComm_t[]> comms;
  int ndevices;
  NcclCommList(const std::vector<int>& devices)
      : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
    NCCL_CHECK(ncclCommInitAll(
        to_nccl_comm(comms.get()), devices.size(), devices.data()));
  }
  NcclCommList(NcclCommList&&) = default;
  ~NcclCommList() {
    if (comms) {
      for (const auto i : c10::irange(ndevices)) {
        int dummy_var;
        if (C10_CUDA_ERROR_HANDLED(cudaGetDevice(&dummy_var)) != cudaSuccess) {
          /* There are cases where this destructor runs after the CUDA driver
             has already been unloaded from the process. In those cases, skip
             ncclCommDestroy. */
          return;
        }
        comm_destroy(comms[i]);
      }
    }
  }
  ArrayRef<ncclComm_t> ref() const {
    return ArrayRef<ncclComm_t>(comms.get(), ndevices);
  }
};

using device_list = std::vector<int>;
// Accesses to this object have to be guarded by the CUDA free mutex
// (c10::cuda::getFreeMutex()).
static std::unordered_map<device_list, NcclCommList, c10::hash<device_list>>
    _communicators;

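// Returns the communicator clique for the exact set of devices the inputs
// live on, creating and caching it on first use.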
ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
  static auto get_device = [](const at::Tensor& t) -> int {
    return t.get_device();
  };
  device_list devices = fmap(inputs, get_device);
  auto it = _communicators.find(devices);
  if (it == _communicators.end())
    std::tie(it, std::ignore) = _communicators.emplace(devices, devices);
  return it->second.ref();
}

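// Validates one input tensor (and, if present, its output): both must be
// dense, contiguous CUDA tensors of the reference dtype; the input must have
// ref_numel elements, and the output must satisfy
//   output->numel() * output_multiplier == ref_numel * input_multiplier.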
static inline void check_tensor(
    const at::Tensor& input,
    const at::optional<at::Tensor>& output,
    int input_multiplier,
    int output_multiplier,
    int64_t ref_numel,
    ScalarType ref_dtype) {
  auto check_one = [&](const at::Tensor& tensor) {
    if (!tensor.is_cuda() || tensor.is_sparse()) {
      throw std::runtime_error(
          "input and output elements have to be cuda dense Tensors");
    }

    if (ref_dtype != tensor.scalar_type()) {
      throw std::runtime_error(
          "all inputs and outputs must be of the same Tensor dtype");
    }

    if (!tensor.is_contiguous()) {
      throw std::runtime_error("all inputs and outputs have to be contiguous");
    }
  };

  check_one(input);

  // all inputs must have the same number of elements
  if (input.numel() != ref_numel) {
    throw std::runtime_error(
        "all inputs must have the same number of elements");
  }

  if (output) {
    check_one(*output);

    // each input must be on the same device as its corresponding output
    if (input.get_device() != output->get_device()) {
      throw std::runtime_error("input and output must be on the same device");
    }

    if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
      throw std::runtime_error(
          "output must be of size input_size * size_multiplier");
    }
  }
}

void check_inputs(
    TensorList inputs,
    TensorList outputs,
    int input_multiplier,
    int output_multiplier) {
  // len(inputs) == len(outputs)
  size_t len = inputs.size();

  if (len == 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  if (len != outputs.size()) {
    std::stringstream err;
    err << "inputs and outputs sequences have to be of the same length, but got input of length "
        << len << " and output of length " << outputs.size();
    throw std::runtime_error(err.str());
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];
    auto output = outputs[i];

    check_tensor(
        input, output, input_multiplier, output_multiplier, numel, dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

void check_inputs(
    TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier) {
  size_t len = inputs.size();

  if (len == 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];

    check_tensor(
        input,
        i == static_cast<size_t>(root) ? at::optional<at::Tensor>{output}
                                       : at::nullopt,
        input_multiplier,
        output_multiplier,
        numel,
        dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

} // namespace detail

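// RAII guard around an NCCL group: holds c10::cuda::getFreeMutex() for the
// lifetime of the group and brackets it with ncclGroupStart/ncclGroupEnd on
// NCCL >= 2.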
AutoNcclGroup::AutoNcclGroup() {
  (c10::cuda::getFreeMutex())->lock();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupEnd());
#endif
  (c10::cuda::getFreeMutex())->unlock();
}

bool is_available(TensorList tensors) {
#ifdef USE_NCCL
  device_set devices;
  for (auto& tensor : tensors) {
    if (!tensor.is_cuda() || tensor.is_sparse())
      return false;
    if (!tensor.is_contiguous())
      return false;
    auto device = tensor.get_device();
    if (devices[device])
      return false;
    devices[device] = true;
  }
  return true;
#else
  return false;
#endif
}

std::uint64_t version() {
#if defined(NCCL_MAJOR)
  constexpr std::uint64_t ver = (((uint64_t)NCCL_MAJOR) << 32) |
      (((uint64_t)NCCL_MINOR) << 16) | ((uint64_t)NCCL_PATCH);
  return ver;
#elif defined(USE_NCCL)
  // return major version "1"
  return ((uint64_t)1) << 32;
#else
  return 0;
#endif
}

void get_unique_id(ncclUniqueId& id) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  ncclComm_t comm;
  ncclUniqueId id = comm_id;
  NCCL_CHECK(ncclCommInitRank(
      to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank));
  return comm;
#else
  return nullptr;
#endif
}

void comm_destroy(ncclComm_t comm) {
  /*
   * TODO(T30279827) Temporarily disable calling ncclCommDestroy
   * Calling ncclCommDestroy while the program is exiting is undefined
   * behavior according to Nvidia, and leads to a segfault in NCCL 2
   * (whether it is called before or after the CUDA runtime destructor).
   * Temporarily disable it in the destructor to avoid the segfault.
   * Following up with Nvidia for a long term solution.
   */
  return;

#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}

namespace {
// NCCL changed the numerical type used for `count` between NCCL1 and NCCL2.
// The trait below extracts the type of the second argument of a function
// type; instantiating it with decltype(ncclBcast) recovers that count type
// statically, whichever NCCL version we are building against.

template <typename T>
struct GetSecondArgType;

template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
  typedef typename std::decay<Arg1>::type type;
};

constexpr auto count_max =
    std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
} // namespace

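// Largest element count a single NCCL call accepts, derived from the type of
// ncclBcast's count parameter above.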
size_t get_max_count() {
  return count_max;
}

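// In-place broadcast: rank 0's tensor is broadcast to all tensors in the
// list, one ncclBcast per participating device, issued inside a single NCCL
// group.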
void broadcast(
    TensorList tensors,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(tensors, tensors, 1, 1);
  auto data_type = to_nccl_data_type(tensors[0]);
  int64_t numel = tensors[0].numel();

  const auto comms = user_comms.empty() ? get_communicators(tensors)
                                        : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
    int device = tensors[i].get_device();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();
    TORCH_CHECK(
        static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
        "Broadcast tensor has ",
        numel,
        " elements, which exceeds the "
        "maximum NCCL supports (",
        count_max,
        ")");
    ncclComm_t comm = comms[i];
    NCCL_CHECK(ncclBcast(
        tensors[i].data_ptr(),
        numel,
        data_type,
        0,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

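// Reduces inputs[i] across devices with the given reduction op; the result
// lands in `output` on the root device. ncclReduce only consults the receive
// buffer on the root rank, so non-root ranks pass nullptr.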
void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  TORCH_CHECK(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");

  check_inputs(inputs, output, root, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduce(
        inputs[i].data_ptr(),
        static_cast<size_t>(root) == i ? output.data_ptr() : nullptr,
        count,
        data_type,
        to_nccl_red_op(op),
        root,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
  reduce(inputs, /*output=*/inputs[root], root, op, streams, user_comms);
}

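// Element-wise all-reduce: every outputs[i] ends up holding the reduction of
// all inputs across the clique.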
void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(inputs, outputs, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclAllReduce(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

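// Reduce-scatter: the reduction of all inputs is split into len equal chunks,
// and outputs[i] receives chunk i (inputs[0].numel() / len elements each).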
void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, 1, len);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel() / len;
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduceScatter(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

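// All-gather: every outputs[i] receives the concatenation of all inputs
// (len * inputs[0].numel() elements). NCCL 1 used a different argument order
// for ncclAllGather, hence the version split below.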
void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, len, 1);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_comm(comm),
        stream));
#else
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        count,
        data_type,
        outputs[i].data_ptr(),
        to_nccl_comm(comm),
        stream));
#endif
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

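// Single-tensor all-to-all with equal splits: the tensor is divided into
// `size` equal chunks; chunk r of `input` is sent to rank r, and the chunk
// received from rank r is written to chunk r of `output`. On recent ROCm this
// maps directly onto ncclAllToAll; otherwise it is emulated with paired
// ncclSend/ncclRecv calls inside one NCCL group.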
void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  int numranks;
  auto type = to_nccl_data_type(input);
  size_t count = input.numel() / size;
  size_t rankdiff = input.nbytes() / size;
  const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
  auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
  auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM) && ROCM_VERSION >= 50000
  NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    // NCCL uses a 0-byte message for synchronization.
    // Avoid send/recv when the message size is zero.
    if (count != 0) {
      NCCL_CHECK(
          ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
      NCCL_CHECK(
          ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

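// Variable-split all-to-all over raw buffers: sendcounts/recvcounts give the
// per-rank element counts, senddispls/recvdispls the per-rank element
// offsets, and `size` is the element size in bytes used to turn offsets into
// byte addresses.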
void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType _type,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  auto type = to_nccl_data_type(_type);
  auto comm = to_nccl_comm(_comm);
  int numranks;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    // NCCL uses a 0-byte message for synchronization.
    // Avoid send/recv when the message size is zero.
    if (sendcounts[r] != 0) {
      NCCL_CHECK(ncclSend(
          ((char*)sendbuff) + senddispls[r] * size,
          sendcounts[r],
          type,
          r,
          comm,
          stream));
    }
    if (recvcounts[r] != 0) {
      NCCL_CHECK(ncclRecv(
          ((char*)recvbuff) + recvdispls[r] * size,
          recvcounts[r],
          type,
          r,
          comm,
          stream));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

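// Tensor-list all-to-all: inputTensors[r] is sent to rank r and the tensor
// received from rank r lands in outputTensors[r]; empty tensors are skipped.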
void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);

  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(outputTensors.size())) {
    at::Tensor& input = inputTensors[r];
    at::Tensor& output = outputTensors[r];
    if (input.numel() != 0) {
      NCCL_CHECK(ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          r,
          comm,
          stream.stream()));
    }
    if (output.numel() != 0) {
      NCCL_CHECK(ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          r,
          comm,
          stream.stream()));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

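// Point-to-point send of a single tensor to rank `dst` on the given stream
// (requires NCCL >= 2.7 for ncclSend).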
void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 7)
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclSend(
      input.data_ptr(),
      input.numel(),
      to_nccl_data_type(input),
      dst,
      to_nccl_comm(comm),
      stream.stream()));
#else
  AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

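// Point-to-point receive of a single tensor from rank `src` on the given
// stream (requires NCCL >= 2.7 for ncclRecv).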
void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 7)
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclRecv(
      output.data_ptr(),
      output.numel(),
      to_nccl_data_type(output),
      src,
      to_nccl_comm(comm),
      stream.stream()));
#else
  AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

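// Gather: every non-root rank sends its `inputs` tensor to the root; the root
// receives rank r's tensor into outputs[r] and copies its own input locally.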
void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

  size_t count = inputs.numel();
  auto type = to_nccl_data_type(inputs);
  const auto* sendbuff = reinterpret_cast<char*>(inputs.data_ptr());

  NCCL_CHECK(ncclGroupStart());

  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
        NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
      } else {
        // on its own rank, simply copy from the input
        outputs[r].copy_(inputs);
      }
    }
  } else {
    NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
  }
  NCCL_CHECK(ncclGroupEnd());

#else
  AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

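// Scatter: the root sends inputs[r] to rank r (copying its own slot locally);
// every other rank receives its tensor from the root into `outputs`.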
void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

  NCCL_CHECK(ncclGroupStart());
  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        size_t send_count = inputs[r].numel();
        auto send_type = to_nccl_data_type(inputs[r]);
        const auto* sendbuff = reinterpret_cast<char*>(inputs[r].data_ptr());
        NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
      } else {
        // on its own rank, simply copy it to the output
        outputs.copy_(inputs[r]);
      }
    }
  } else {
    size_t recv_count = outputs.numel();
    auto recv_type = to_nccl_data_type(outputs);
    auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
    NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
  }
  NCCL_CHECK(ncclGroupEnd());

#else
  AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

} // namespace nccl
} // namespace cuda
} // namespace torch