#include <ATen/core/functional.h>
#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/cuda/nccl.h>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>

#include <nccl.h>

#include <limits>
#include <sstream>
#include <type_traits>
#include <unordered_map>

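// The public torch::cuda::nccl header wraps NCCL's opaque types in its own
// typedefs so that nccl.h stays a private dependency of this file. The
// helpers below reinterpret between the public wrappers and the real NCCL
// types, which the header mirrors layout-compatibly.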
ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
  return reinterpret_cast<ncclComm_t*>(var);
}

ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
  return reinterpret_cast<ncclComm_t>(var);
}

ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
  return reinterpret_cast<ncclUniqueId*>(var);
}

ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
  switch (var) {
    case torch::cuda::nccl::ncclResult::Success:
      return ncclResult_t::ncclSuccess;
    case torch::cuda::nccl::ncclResult::UnhandledCudaError:
      return ncclResult_t::ncclUnhandledCudaError;
    case torch::cuda::nccl::ncclResult::SystemError:
      return ncclResult_t::ncclSystemError;
    case torch::cuda::nccl::ncclResult::InternalError:
      return ncclResult_t::ncclInternalError;
    case torch::cuda::nccl::ncclResult::InvalidArgument:
      return ncclResult_t::ncclInvalidArgument;
    case torch::cuda::nccl::ncclResult::InvalidUsage:
      return ncclResult_t::ncclInvalidUsage;
    case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL result");
  }
}

torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
  switch (var) {
    case ncclSuccess:
      return torch::cuda::nccl::ncclResult::Success;
    case ncclUnhandledCudaError:
      return torch::cuda::nccl::ncclResult::UnhandledCudaError;
    case ncclSystemError:
      return torch::cuda::nccl::ncclResult::SystemError;
    case ncclInternalError:
      return torch::cuda::nccl::ncclResult::InternalError;
    case ncclInvalidArgument:
      return torch::cuda::nccl::ncclResult::InvalidArgument;
    case ncclInvalidUsage:
      return torch::cuda::nccl::ncclResult::InvalidUsage;
    case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL result");
  }
}

ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
  switch (type) {
    case at::kFloat:
      return ncclDataType_t::ncclFloat;
    case at::kHalf:
      return ncclDataType_t::ncclHalf;
    case at::kDouble:
      return ncclDataType_t::ncclDouble;
    case at::kLong:
      return ncclDataType_t::ncclInt64;
    case at::kInt:
      return ncclDataType_t::ncclInt;
    case at::kChar:
      return ncclDataType_t::ncclChar;
    case at::kByte:
      return ncclDataType_t::ncclUint8;
    case at::kBool:
      // NCCL has no native bool type; bool tensors travel as uint8
      return ncclDataType_t::ncclUint8;
#if HAS_NCCL_BF16_DATATYPE
    case at::kBFloat16:
      return ncclDataType_t::ncclBfloat16;
#endif
    default:
      TORCH_CHECK(false, "Unconvertible NCCL type ", type);
  }
}

ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
  if (!t.is_cuda()) {
    TORCH_CHECK(
        false,
        "NCCL only supports CUDA tensors, but got a tensor on ",
        t.device());
  }
  return to_nccl_data_type(t.scalar_type());
}

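// The integer values follow NCCL's ncclRedOp_t ordering, e.g.
// to_nccl_red_op(0) == ncclSum, to_nccl_red_op(1) == ncclProd,
// to_nccl_red_op(2) == ncclMax and to_nccl_red_op(3) == ncclMin, so callers
// are expected to pass values that originate from ncclRedOp_t.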
ncclRedOp_t to_nccl_red_op(int var) {
  return static_cast<ncclRedOp_t>(var);
}

namespace torch {
namespace cuda {
namespace nccl {

using namespace at;

namespace detail {

// Not recursion: the inner call resolves to the NCCL_CHECK overload in
// torch/csrc/cuda/nccl.h that takes the wrapper ncclResult.
static inline void NCCL_CHECK(ncclResult_t result) {
  NCCL_CHECK(from_nccl_result(result));
}

void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
  std::ostringstream err;
  err << "NCCL Error " << static_cast<int>(status) << ": "
      << ncclGetErrorString(to_nccl_result(status));
  throw std::runtime_error(err.str());
}

struct NcclCommList {
  std::unique_ptr<ncclComm_t[]> comms;
  int ndevices;
  NcclCommList(const std::vector<int>& devices)
      : comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
    NCCL_CHECK(ncclCommInitAll(
        to_nccl_comm(comms.get()), devices.size(), devices.data()));
  }
  NcclCommList(NcclCommList&& foo) = default;
  ~NcclCommList() {
    if (comms) {
      for (const auto i : c10::irange(ndevices)) {
        int dummy_var;
        if (C10_CUDA_ERROR_HANDLED(cudaGetDevice(&dummy_var)) != cudaSuccess) {
          /* there are cases when this destructor is called after the
             CUDA driver is already unloaded from the process.
             In these cases, skip ncclCommDestroy */
          return;
        }
        comm_destroy(comms[i]);
      }
    }
  }
  ArrayRef<ncclComm_t> ref() const {
    return ArrayRef<ncclComm_t>(comms.get(), ndevices);
  }
};

using device_list = std::vector<int>;
// accesses to this object have to be guarded by the CUDA free mutex
// (c10::cuda::getFreeMutex())
static std::unordered_map<device_list, NcclCommList, c10::hash<device_list>>
    _communicators;

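// Example: the first call with tensors on devices {0, 1} initializes a
// two-comm clique via ncclCommInitAll and caches it under that device list;
// later calls with the same device ordering reuse the cached comms, while a
// permuted ordering such as {1, 0} is a distinct key with its own clique.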
ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
  static auto get_device = [](const at::Tensor& t) -> int {
    return t.get_device();
  };
  device_list devices = fmap(inputs, get_device);
  auto it = _communicators.find(devices);
  if (it == _communicators.end())
    std::tie(it, std::ignore) = _communicators.emplace(devices, devices);
  return it->second.ref();
}

static inline void check_tensor(
    const at::Tensor& input,
    const at::optional<at::Tensor>& output,
    int input_multiplier,
    int output_multiplier,
    int64_t ref_numel,
    ScalarType ref_dtype) {
  auto check_one = [&](const at::Tensor& tensor) {
    if (!tensor.is_cuda() || tensor.is_sparse()) {
      throw std::runtime_error(
          "input and output elements have to be cuda dense Tensors");
    }

    if (ref_dtype != tensor.scalar_type()) {
      throw std::runtime_error(
          "all inputs and outputs must be of the same Tensor dtype");
    }

    if (!tensor.is_contiguous()) {
      throw std::runtime_error("all inputs and outputs have to be contiguous");
    }
  };

  check_one(input);

  // all inputs must be same size
  if (input.numel() != ref_numel) {
    throw std::runtime_error(
        "all inputs must have the same number of elements");
  }

  if (output) {
    check_one(*output);

    // inputs and outputs must be on same device respectively
    if (input.get_device() != output->get_device()) {
      throw std::runtime_error("input and output must be on the same device");
    }

    if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
      throw std::runtime_error(
          "output must be of size input_size * size_multiplier");
    }
  }
}

void check_inputs(
    TensorList inputs,
    TensorList outputs,
    int input_multiplier,
    int output_multiplier) {
  // len(inputs) == len(outputs)
  size_t len = inputs.size();

  if (len == 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  if (len != outputs.size()) {
    std::stringstream err;
    err << "inputs and outputs sequences have to be of the same length, but got input of length "
        << len << " and output of length " << outputs.size();
    throw std::runtime_error(err.str());
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];
    auto output = outputs[i];

    check_tensor(
        input, output, input_multiplier, output_multiplier, numel, dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

void check_inputs(
    TensorList inputs,
    const at::Tensor& output,
    int root,
    int input_multiplier,
    int output_multiplier) {
  size_t len = inputs.size();

  if (len == 0) {
    throw std::runtime_error("input sequence can't be empty");
  }

  device_set devices;
  int64_t numel = inputs[0].numel();
  auto dtype = inputs[0].scalar_type();

  for (const auto i : c10::irange(len)) {
    auto input = inputs[i];

    check_tensor(
        input,
        i == static_cast<size_t>(root) ? at::optional<at::Tensor>{output}
                                       : at::nullopt,
        input_multiplier,
        output_multiplier,
        numel,
        dtype);

    auto input_device = input.get_device();
    // inputs must be on unique devices
    if (devices.test(input_device)) {
      throw std::runtime_error("inputs must be on unique devices");
    }
    devices.set(input_device);
  }
}

} // namespace detail

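// RAII usage sketch: the guard serializes NCCL calls against cudaFree and
// batches everything issued in its scope into one NCCL group:
//
//   {
//     AutoNcclGroup group; // takes the free mutex, calls ncclGroupStart()
//     // ... enqueue per-device NCCL calls ...
//   } // ncclGroupEnd() and mutex release on scope exit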
AutoNcclGroup::AutoNcclGroup() {
  (c10::cuda::getFreeMutex())->lock();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupStart());
#endif
}

AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
  detail::NCCL_CHECK(ncclGroupEnd());
#endif
  (c10::cuda::getFreeMutex())->unlock();
}

bool is_available(TensorList tensors) {
#ifdef USE_NCCL
  device_set devices;
  for (auto& tensor : tensors) {
    if (!tensor.is_cuda() || tensor.is_sparse())
      return false;
    if (!tensor.is_contiguous())
      return false;
    auto device = tensor.get_device();
    if (devices[device])
      return false;
    devices[device] = true;
  }
  return true;
#else
  return false;
#endif
}

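// Packs the NCCL version into one integer; for example, NCCL 2.14.3 yields
// (2 << 32) | (14 << 16) | 3 == 0x2000E0003.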
std::uint64_t version() {
#if defined(NCCL_MAJOR)
  constexpr std::uint64_t ver = (((uint64_t)NCCL_MAJOR) << 32) |
      (((uint64_t)NCCL_MINOR) << 16) | ((uint64_t)NCCL_PATCH);
  return ver;
#elif defined(USE_NCCL)
  // NCCL1 predates the NCCL_MAJOR macro; report it as major version 1
  return ((uint64_t)1) << 32;
#else
  return 0;
#endif
}

void get_unique_id(ncclUniqueId& id) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

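// Bootstrap sketch for a multi-process setup (one process per rank; how the
// id reaches the other ranks, e.g. via a TCP store, is up to the caller and
// assumed here):
//
//   torch::cuda::nccl::ncclUniqueId id;
//   if (rank == 0)
//     torch::cuda::nccl::get_unique_id(id);
//   // ... share `id` with all ranks out of band ...
//   auto comm = torch::cuda::nccl::comm_init_rank(nranks, id, rank);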
ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  ncclComm_t comm;
  ncclUniqueId id = comm_id;
  NCCL_CHECK(ncclCommInitRank(
      to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank));
  return comm;
#else
  return nullptr;
#endif
}

void comm_destroy(ncclComm_t comm) {
  /*
   * TODO(T30279827) Temporarily disable calling ncclCommDestroy
   * Calling ncclCommDestroy while program exiting is undefined
   * according to Nvidia, and leads to a segfault in NCCL 2
   * (whether it is called before or after the CUDA runtime destructor).
   * Temporarily disable it in destructor to avoid segfault.
   * Following up with Nvidia for long term solution.
   */
  return;

#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}

namespace {
// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
// The trait below extracts the type of the second argument of a function
// type T; instantiating it with ncclBcast recovers NCCL's count type
// statically, whichever version we build against.

template <typename T>
struct GetSecondArgType;

template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
  typedef typename std::decay<Arg1>::type type;
};

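// With NCCL2 the count argument is size_t, so count_max is SIZE_MAX and the
// check below is effectively never hit; with NCCL1 it was int, capping a
// single operation at INT_MAX elements.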
constexpr auto count_max =
    std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
} // namespace

size_t get_max_count() {
  return count_max;
}

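// Usage sketch for broadcast() (assumes two visible CUDA devices; with the
// stream and comm lists left at their empty defaults, the current streams
// and the cached communicator clique are used):
//
//   auto opts = at::dtype(at::kFloat);
//   auto src = at::ones({4}, opts.device(at::Device(at::kCUDA, 0)));
//   auto dst = at::empty({4}, opts.device(at::Device(at::kCUDA, 1)));
//   torch::cuda::nccl::broadcast({src, dst}); // dst now holds src's data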
void broadcast(
    TensorList tensors,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(tensors, tensors, 1, 1);
  auto data_type = to_nccl_data_type(tensors[0]);
  int64_t numel = tensors[0].numel();

  const auto comms = user_comms.empty() ? get_communicators(tensors)
                                        : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
    int device = tensors[i].get_device();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();
    TORCH_CHECK(
        static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
        "Broadcast tensor has ",
        numel,
        " elements, which exceeds the "
        "maximum NCCL supports (",
        count_max,
        ")");
    ncclComm_t comm = comms[i];
    NCCL_CHECK(ncclBcast(
        tensors[i].data_ptr(),
        numel,
        data_type,
        0,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& output,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  TORCH_CHECK(
      root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");

  check_inputs(inputs, output, root, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduce(
        inputs[i].data_ptr(),
        static_cast<size_t>(root) == i ? output.data_ptr() : nullptr,
        count,
        data_type,
        to_nccl_red_op(op),
        root,
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

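// In-place variant: rank `root`'s input tensor doubles as the output buffer.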
void reduce(
    std::vector<at::Tensor>& inputs,
    int32_t root,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
  reduce(inputs, /*output=*/inputs[root], root, op, streams, user_comms);
}

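// all_reduce() permits inputs and outputs to alias (the checks only require
// matching dtype, numel and device), so a common in-place sketch is, with
// op 0 being ncclSum:
//
//   torch::cuda::nccl::all_reduce(ts, ts, /*op=*/0);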
void all_reduce(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  check_inputs(inputs, outputs, 1, 1);
  const auto len = inputs.size();

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclAllReduce(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void reduce_scatter(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    int32_t op,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, 1, len);

  auto data_type = to_nccl_data_type(inputs[0]);

  // each participant keeps a 1/len chunk of the reduced result
  const auto count = inputs[0].numel() / len;
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
    NCCL_CHECK(ncclReduceScatter(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_red_op(op),
        to_nccl_comm(comm),
        stream));
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all_gather(
    const std::vector<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs,
    const stream_list& streams,
    const comm_list& user_comms) {
#ifdef USE_NCCL
  using namespace torch::cuda::nccl::detail;
  const auto len = inputs.size();
  check_inputs(inputs, outputs, len, 1);

  auto data_type = to_nccl_data_type(inputs[0]);

  const auto count = inputs[0].numel();
  auto comms_ref = user_comms.empty() ? get_communicators(inputs)
                                      : ArrayRef<ncclComm_t>(user_comms);

  AutoNcclGroup nccl_group_guard;
  at::cuda::OptionalCUDAGuard device_guard;
  for (const auto i : c10::irange(len)) {
    int device = inputs[i].device().index();
    device_guard.set_index(device);
    // Default to the current stream
    const auto stream = (streams.empty() || !streams[i])
        ? at::cuda::getCurrentCUDAStream(device).stream()
        : streams[i]->stream();

    ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        outputs[i].data_ptr(),
        count,
        data_type,
        to_nccl_comm(comm),
        stream));
#else
    // NCCL1 took the output buffer after the count and data type
    NCCL_CHECK(ncclAllGather(
        inputs[i].data_ptr(),
        count,
        data_type,
        outputs[i].data_ptr(),
        to_nccl_comm(comm),
        stream));
#endif
  }
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all_single_equal_split(
    at::Tensor& input,
    at::Tensor& output,
    int size,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  int numranks;
  auto type = to_nccl_data_type(input);
  size_t count = input.numel() / size; // elements exchanged with each rank
  size_t rankdiff = input.nbytes() / size; // byte stride between chunks
  const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
  auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
  auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM) && ROCM_VERSION >= 50000
  NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    // NCCL uses 0 byte message for synchronization
    // Avoid send/recv when message size is zero
    if (count != 0) {
      NCCL_CHECK(
          ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
      NCCL_CHECK(
          ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#endif
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

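// Variable-count all-to-all: sendcounts/recvcounts hold per-rank element
// counts and senddispls/recvdispls per-rank element offsets; `size` is the
// element size in bytes and is used below to scale the displacements.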
void all2all_single_unequal_split(
    void* sendbuff,
    const size_t* sendcounts,
    const size_t* senddispls,
    void* recvbuff,
    const size_t* recvcounts,
    const size_t* recvdispls,
    size_t size,
    c10::ScalarType _type,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  auto type = to_nccl_data_type(_type);
  auto comm = to_nccl_comm(_comm);
  int numranks;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(numranks)) {
    // NCCL uses 0 byte message for synchronization
    // Avoid send/recv when message size is zero
    if (sendcounts[r] != 0) {
      NCCL_CHECK(ncclSend(
          ((char*)sendbuff) + senddispls[r] * size,
          sendcounts[r],
          type,
          r,
          comm,
          stream));
    }
    if (recvcounts[r] != 0) {
      NCCL_CHECK(ncclRecv(
          ((char*)recvbuff) + recvdispls[r] * size,
          recvcounts[r],
          type,
          r,
          comm,
          stream));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void all2all(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;
  auto comm = to_nccl_comm(_comm);

  NCCL_CHECK(ncclGroupStart());
  for (const auto r : c10::irange(outputTensors.size())) {
    at::Tensor& input = inputTensors[r];
    at::Tensor& output = outputTensors[r];
    if (input.numel() != 0) {
      NCCL_CHECK(ncclSend(
          input.data_ptr(),
          input.numel(),
          to_nccl_data_type(input),
          r,
          comm,
          stream.stream()));
    }
    if (output.numel() != 0) {
      NCCL_CHECK(ncclRecv(
          output.data_ptr(),
          output.numel(),
          to_nccl_data_type(output),
          r,
          comm,
          stream.stream()));
    }
  }
  NCCL_CHECK(ncclGroupEnd());
#else
  AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

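// Point-to-point sketch: ranks sharing a communicator pair up their sends
// and receives, e.g.
//
//   if (rank == 0)
//     torch::cuda::nccl::send(t, comm, stream, /*dst=*/1);
//   else if (rank == 1)
//     torch::cuda::nccl::recv(t, comm, stream, /*src=*/0);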
void send(
    const at::Tensor& input,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int dst) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 7)
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclSend(
      input.data_ptr(),
      input.numel(),
      to_nccl_data_type(input),
      dst,
      to_nccl_comm(comm),
      stream.stream()));
#else
  AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void recv(
    at::Tensor& output,
    ncclComm_t comm,
    at::cuda::CUDAStream stream,
    int src) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
    (NCCL_MINOR >= 7)
  using namespace torch::cuda::nccl::detail;
  NCCL_CHECK(ncclRecv(
      output.data_ptr(),
      output.numel(),
      to_nccl_data_type(output),
      src,
      to_nccl_comm(comm),
      stream.stream()));
#else
  AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

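// On the root rank, outputs[r] receives rank r's contribution (so outputs
// must hold one tensor per rank, each shaped like the input); non-root
// ranks only send, and their outputs vector is not touched.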
void gather(
    const at::Tensor& inputs,
    std::vector<at::Tensor>& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

  size_t count = inputs.numel();
  auto type = to_nccl_data_type(inputs);
  const auto* sendbuff = reinterpret_cast<char*>(inputs.data_ptr());

  NCCL_CHECK(ncclGroupStart());

  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
        NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
      } else {
        // on its own rank, simply copy from the input
        outputs[r].copy_(inputs);
      }
    }
  } else {
    NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
  }
  NCCL_CHECK(ncclGroupEnd());

#else
  AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

void scatter(
    const std::vector<at::Tensor>& inputs,
    at::Tensor& outputs,
    ncclComm_t _comm,
    at::cuda::CUDAStream& stream,
    int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
    (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
  using namespace torch::cuda::nccl::detail;

  auto comm = to_nccl_comm(_comm);
  int numranks, cur_rank;
  NCCL_CHECK(ncclCommCount(comm, &numranks));
  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));

  NCCL_CHECK(ncclGroupStart());
  if (cur_rank == root) {
    for (const auto r : c10::irange(numranks)) {
      if (r != root) {
        size_t send_count = inputs[r].numel();
        auto send_type = to_nccl_data_type(inputs[r]);
        const auto* sendbuff = reinterpret_cast<char*>(inputs[r].data_ptr());
        NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
      } else {
        // on its own rank, simply copy it to the output
        outputs.copy_(inputs[r]);
      }
    }
  } else {
    size_t recv_count = outputs.numel();
    auto recv_type = to_nccl_data_type(outputs);
    auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
    NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
  }
  NCCL_CHECK(ncclGroupEnd());

#else
  AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
#else
  AT_ERROR("PyTorch built without NCCL support");
#endif
}

} // namespace nccl
} // namespace cuda
} // namespace torch
920 | |