1 | #include <torch/csrc/distributed/c10d/NCCLUtils.hpp> |
2 | #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp> |
3 | #include <torch/csrc/distributed/c10d/UCCForNCCL.hpp> |
4 | #include <sstream> |
5 | |
6 | #ifdef USE_C10D_NCCL |
7 | |
8 | #include <exception> |
9 | #include <map> |
10 | #include <stdexcept> |
11 | #include <tuple> |
12 | #include <unordered_set> |
13 | #include <utility> |
14 | |
15 | #include <ATen/cuda/CUDAContext.h> |
16 | #include <c10/core/DeviceType.h> |
17 | #include <c10/cuda/CUDAGraphsC10Utils.h> |
18 | #include <c10/cuda/CUDAGuard.h> |
19 | #include <c10/util/CallOnce.h> |
20 | #include <c10/util/Exception.h> |
21 | #include <c10/util/Logging.h> |
22 | #include <c10/util/Optional.h> |
23 | #include <c10/util/irange.h> |
24 | #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp> |
25 | #include <torch/csrc/distributed/c10d/TraceUtils.h> |
26 | #include <torch/csrc/distributed/c10d/Utils.hpp> |
27 | |
28 | #include <torch/csrc/cuda/nccl.h> |
29 | |
30 | namespace c10d { |
31 | |
constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM";
33 | |
34 | namespace { |
35 | |
36 | #if defined(NCCL_MAJOR) && \ |
37 | ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10)) |
38 | #define NCCL_HAS_AVG 1 |
39 | #endif |
40 | |
41 | // NCCL op mapping |
42 | const std::map<ReduceOp::RedOpType, ncclRedOp_t> ncclOp = { |
43 | {ReduceOp::MIN, ncclMin}, |
44 | {ReduceOp::MAX, ncclMax}, |
45 | {ReduceOp::SUM, ncclSum}, |
46 | {ReduceOp::PRODUCT, ncclProd}, |
47 | #ifdef NCCL_HAS_AVG |
48 | {ReduceOp::AVG, ncclAvg}, |
49 | #endif |
50 | }; |
51 | |
// NCCL data type mapping
53 | std::map<at::ScalarType, ncclDataType_t> ncclDataType = { |
54 | {at::kChar, ncclInt8}, |
55 | {at::kByte, ncclUint8}, |
56 | {at::kFloat, ncclFloat}, |
57 | {at::kDouble, ncclDouble}, |
58 | {at::kInt, ncclInt32}, |
59 | {at::kLong, ncclInt64}, |
60 | {at::kHalf, ncclHalf}, |
61 | {at::kBool, ncclUint8}, |
62 | #if HAS_NCCL_BF16_DATATYPE |
63 | {at::kBFloat16, ncclBfloat16}, |
64 | #endif |
65 | }; |
66 | |
// Helper function that gets the NCCL data type for a given ATen scalar type
// and raises an error if the type is not supported
68 | ncclDataType_t getNcclDataType(at::ScalarType type) { |
69 | auto it = ncclDataType.find(type); |
70 | TORCH_CHECK( |
71 | it != ncclDataType.end(), |
72 | "Input tensor data type is not supported for NCCL process group: " , |
73 | type); |
74 | return it->second; |
75 | } |
76 | |
77 | #ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT |
78 | template <typename T, ncclDataType_t dataType> |
79 | ncclRedOpRAII unpackPreMulSum( |
80 | const ReduceOp& reduceOp, |
81 | const ncclComm_t& comm, |
82 | int dev_in_group) { |
83 | const auto* preMulSupplement = |
84 | reinterpret_cast<NCCLPreMulSumSupplement*>(reduceOp.supplement_.get()); |
85 | ncclRedOp_t preMulSum; |
86 | bool has_tensor = preMulSupplement->tensor_factor.defined(); |
87 | auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; |
88 | T* ptr_factor = |
89 | has_tensor ? preMulSupplement->tensor_factor.data_ptr<T>() : nullptr; |
90 | T scalar_factor = T(preMulSupplement->double_factor); |
91 | ncclRedOpCreatePreMulSum( |
92 | &preMulSum, |
93 | has_tensor ? ptr_factor : &scalar_factor, |
94 | dataType, |
95 | residence, |
96 | comm); |
97 | return ncclRedOpRAII(preMulSum, comm); |
98 | } |
99 | #endif |
100 | |
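// Map a c10d ReduceOp to the corresponding ncclRedOp_t. Handles the special
// cases below: SUM on bool tensors becomes MAX (a bitwise OR on 0/1 values),
// AVG is rejected for bool tensors, and PREMUL_SUM builds a custom NCCL
// reduction op (NCCL >= 2.11.1) via ncclRedOpCreatePreMulSum.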
101 | ncclRedOpRAII getNcclReduceOp( |
102 | const ReduceOp& reduceOp, |
103 | at::Tensor& input, |
104 | const ncclDataType_t& dataType, |
105 | const ncclComm_t& comm, |
106 | int dev_in_group) { |
107 | try { |
108 | if (input.scalar_type() == at::kBool) { |
109 | if (reduceOp == ReduceOp::SUM) { |
// For bool tensors, map sum to max; on 0/1 values both represent a
// bitwise or. This is to prevent overflow issues with sum, since we use
// uint8 to represent a bool (see the ncclDataType mapping).
113 | return ncclMax; |
114 | } |
115 | #ifdef NCCL_HAS_AVG |
116 | if (reduceOp == ReduceOp::AVG) { |
117 | TORCH_CHECK(false, "Cannot use ReduceOp.AVG with boolean inputs" ); |
118 | } |
119 | #endif |
120 | } |
121 | if (reduceOp == ReduceOp::PREMUL_SUM) { |
122 | #ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT |
123 | switch (dataType) { |
124 | case ncclHalf: |
125 | return unpackPreMulSum<at::Half, ncclHalf>( |
126 | reduceOp, comm, dev_in_group); |
127 | case ncclFloat: |
128 | return unpackPreMulSum<float, ncclFloat>( |
129 | reduceOp, comm, dev_in_group); |
130 | case ncclDouble: |
131 | return unpackPreMulSum<double, ncclDouble>( |
132 | reduceOp, comm, dev_in_group); |
133 | default: |
134 | TORCH_CHECK( |
135 | false, "PreMulSum Data type must be half, float, or double" ); |
136 | ncclRedOp_t unused; |
137 | return unused; |
138 | } |
139 | #else |
140 | TORCH_CHECK(false, "PreMulSum requires NCCL>=2.11.1" ); |
141 | #endif |
142 | } |
143 | return ncclOp.at(reduceOp); |
144 | } catch (const std::out_of_range& e) { |
145 | switch (reduceOp) { |
146 | case ReduceOp::AVG: |
147 | TORCH_CHECK( |
148 | false, |
149 | "AVG requires NCCL 2.10+. The current version is " , |
150 | NCCL_MAJOR, |
151 | "." , |
152 | NCCL_MINOR); |
153 | break; |
154 | case ReduceOp::BAND: |
155 | TORCH_CHECK(false, "Cannot use ReduceOp.BAND with NCCL" ); |
156 | break; |
157 | case ReduceOp::BOR: |
158 | TORCH_CHECK(false, "Cannot use ReduceOp.BOR with NCCL" ); |
159 | break; |
160 | case ReduceOp::BXOR: |
161 | TORCH_CHECK(false, "Cannot use ReduceOp.BXOR with NCCL" ); |
162 | break; |
163 | default: |
164 | TORCH_CHECK(false, "Unhandled ReduceOp" ); |
165 | break; |
166 | } |
167 | } |
168 | } |
169 | |
// Build the comma-separated device-index string from the list of devices;
// this is used as the key for the communicator cache.
171 | std::string getKeyFromDevices(const std::vector<at::Device>& devices) { |
172 | std::string deviceList; |
173 | for (auto& device : devices) { |
174 | if (deviceList.empty()) { |
175 | deviceList = std::to_string(device.index()); |
176 | } else { |
177 | deviceList += "," + std::to_string(device.index()); |
178 | } |
179 | } |
180 | return deviceList; |
181 | } |
182 | |
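// Build the cache key for a point-to-point communicator from the two ranks
// involved, ordered as "<lowRank>:<highRank>" so both peers compute the same
// key.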
183 | std::string getKeySendRecv(int myRank, int peer) { |
184 | int lowRank = myRank < peer ? myRank : peer; |
185 | int highRank = myRank < peer ? peer : myRank; |
186 | std::string sendRecvPair = |
187 | std::to_string(lowRank) + ":" + std::to_string(highRank); |
188 | return sendRecvPair; |
189 | } |
190 | |
191 | // Get the list of devices from list of tensors |
192 | std::vector<at::Device> getDeviceList(const std::vector<at::Tensor>& tensors) { |
193 | std::vector<at::Device> res; |
194 | res.reserve(tensors.size()); |
195 | for (auto& tensor : tensors) { |
196 | // tensors must all be on the same device, or all on distinct devices. |
197 | // The line below assumes that constraint has already been enforced |
198 | // (by check_gpu_tensors_same_device or |
199 | // check_gpu_tensors_different_devices). |
200 | if (res.size() == 0 || tensor.device() != res[0]) { |
201 | res.push_back(tensor.device()); |
202 | } |
203 | } |
204 | return res; |
205 | } |
206 | |
207 | // Return CUDA device with ordinal given by input rank. |
208 | at::Device getDeviceForRank(int rank) { |
209 | TORCH_CHECK(rank >= 0, "Invalid rank " , rank); |
210 | auto numGPUs = at::cuda::getNumGPUs(); |
211 | int16_t deviceIdx = static_cast<int16_t>(rank % numGPUs); |
212 | return at::Device(at::DeviceType::CUDA, deviceIdx); |
213 | } |
214 | |
// [Sync Streams] Helper that makes the input ncclStreams wait for the current
// stream. NCCL communications run on ncclStreams, but input tensors are
// allocated on different streams (i.e., current streams). Communications on
// ncclStreams cannot start before pending input tensor ops on current streams
// finish. Otherwise, ops on the two streams might read/write the same tensors
// concurrently.
221 | // |
222 | // The synchronization above alone is not enough. We also need to make sure |
223 | // input tensors are not freed before their usages on ncclStreams finish. This |
224 | // can be achieved by calling c10::cuda::CUDACachingAllocator::recordStream, |
225 | // which remembers the usage stream (ncclStream), creates an event on the usage |
226 | // stream when GC attempts to free the input tensor, and delays GC until that |
227 | // event is done. |
228 | void syncStreams( |
229 | const std::vector<at::Device>& devices, |
230 | std::vector<at::cuda::CUDAEvent>& ncclEvents, |
231 | std::vector<at::cuda::CUDAStream>& ncclStreams) { |
232 | for (const auto i : c10::irange(devices.size())) { |
233 | at::cuda::CUDAStream& ncclStream = ncclStreams[i]; |
234 | at::cuda::CUDAEvent& ncclEvent = ncclEvents[i]; |
235 | ncclEvent.record(at::cuda::getCurrentCUDAStream(devices[i].index())); |
236 | ncclEvent.block(ncclStream); |
237 | } |
238 | } |
239 | |
// Given an ncclUniqueId, convert it to a string representation that can be put
241 | // in the store. |
242 | std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { |
243 | const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclID); |
244 | std::ostringstream oss; |
245 | for (const auto i : c10::irange(NCCL_UNIQUE_ID_BYTES)) { |
246 | oss << std::hex << static_cast<int>(bytes[i]); |
247 | } |
248 | return oss.str(); |
249 | } |
250 | |
251 | std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) { |
252 | return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; |
253 | } |
254 | |
255 | // Returns exception's what() given an exception_ptr instance. |
256 | std::string getExceptionMsgFromExceptionPtr( |
257 | const std::exception_ptr& exceptionPtr) { |
258 | TORCH_CHECK(exceptionPtr != nullptr); |
259 | try { |
260 | std::rethrow_exception(exceptionPtr); |
261 | } catch (const std::exception& e) { |
262 | return e.what(); |
263 | } catch (...) { |
264 | return "Unknown exception type" ; |
265 | } |
266 | } |
267 | |
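// CUDA graph capture of NCCL collectives is only supported with NCCL >= 2.9.6;
// this guard errors out if a capture is in progress on an older NCCL.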
268 | inline void errorIfCapturingNonCapturableNCCL() { |
269 | auto status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); |
270 | // parentheses avoid some compiler warnings |
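// 2.9.6 packed as (major << 32) + (minor << 16) + patch, matching the
// encoding returned by torch::cuda::nccl::version().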
271 | static const uint64_t min_version = |
272 | (((uint64_t)2) << 32) + (((uint64_t)9) << 16) + ((uint64_t)6); |
273 | static const uint64_t cur_version = torch::cuda::nccl::version(); |
274 | if (cur_version < min_version) { |
275 | TORCH_CHECK( |
276 | status == c10::cuda::CaptureStatus::None, |
277 | "Capturing NCCL collectives is only allowed with NCCL >= 2.9.6" ); |
278 | } |
279 | } |
280 | |
281 | } // namespace |
282 | |
283 | const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000; |
284 | const int64_t ProcessGroupNCCL::kWorkCleanupThreadSleepMillis = 1000; |
285 | constexpr int64_t kWaitForAbortCommStoreKey = 1000; |
286 | constexpr int64_t kSynchronizeBusyWaitMillis = 10; |
287 | thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; |
288 | |
289 | std::ostream& operator<<( |
290 | std::ostream& output, |
291 | const ProcessGroupNCCL::WorkNCCL& workNCCL) { |
292 | std::string workInfo; |
293 | if (workNCCL.outputs_) { |
294 | workInfo = c10::str( |
295 | "WorkNCCL(" , |
296 | "SeqNum=" , |
297 | workNCCL.seq_, |
298 | ", OpType=" , |
299 | opTypeToString(workNCCL.opType_), |
300 | ", TensorShape=" , |
301 | (*workNCCL.outputs_)[0].sizes(), |
302 | ", Timeout(ms)=" , |
303 | workNCCL.opTimeout_.count(), |
304 | ")" ); |
305 | } else { |
306 | workInfo = c10::str( |
307 | "WorkNCCL(" , |
308 | "SeqNum=" , |
309 | workNCCL.seq_, |
310 | ", OpType=" , |
311 | opTypeToString(workNCCL.opType_), |
312 | ", Timeout(ms)=" , |
313 | workNCCL.opTimeout_.count(), |
314 | ")" ); |
315 | } |
316 | return output << workInfo; |
317 | } |
318 | |
319 | ProcessGroupNCCL::WorkNCCL::WorkNCCL( |
320 | const std::vector<at::Device>& devices, |
321 | int rank, |
322 | OpType opType, |
323 | uint64_t seq, |
324 | const char* profilingTitle, |
325 | const c10::optional<std::vector<at::Tensor>>& inputs, |
326 | bool desyncDebug) |
327 | : Work(rank, opType, profilingTitle, inputs), |
328 | devices_(devices), |
329 | workStartTime_(std::chrono::steady_clock::now()), |
330 | seq_(seq) { |
331 | // Creates the CUDA event wrappers |
332 | // Note: The actual events are lazily created when first recorded to with |
333 | // DEFAULT_FLAGS = cudaEventDisableTiming. |
334 | if (desyncDebug) { |
335 | ncclStartEvents_ = |
336 | std::make_shared<std::vector<at::cuda::CUDAEvent>>(devices.size()); |
337 | } |
338 | ncclEndEvents_ = |
339 | std::make_shared<std::vector<at::cuda::CUDAEvent>>(devices.size()); |
340 | ncclComms_.resize(devices.size()); |
341 | } |
342 | |
343 | ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) |
344 | : Work(w.rank_, w.opType_), |
345 | std::enable_shared_from_this<WorkNCCL>(w), |
346 | devices_(w.devices_), |
347 | ncclStartEvents_(w.ncclStartEvents_), |
348 | ncclEndEvents_(w.ncclEndEvents_), |
349 | ncclComms_(w.ncclComms_), |
350 | blockingWait_(w.blockingWait_), |
351 | opTimeout_(w.opTimeout_), |
352 | workStartTime_(w.workStartTime_), |
353 | seq_(w.seq_), |
354 | startTraceUpdated_(w.startTraceUpdated_), |
355 | store_(w.store_) { |
356 | exception_ = w.exception_; |
357 | } |
358 | |
359 | ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default; |
360 | |
361 | bool ProcessGroupNCCL::WorkNCCL::isCompleted() { |
362 | checkAndSetException(); |
363 | return exception() || finishedGPUExecutionInternal(); |
364 | } |
365 | |
366 | bool ProcessGroupNCCL::WorkNCCL::isStarted() { |
367 | checkAndSetException(); |
368 | return exception() || startedGPUExecutionInternal(); |
369 | } |
370 | |
371 | bool ProcessGroupNCCL::WorkNCCL::isSuccess() const { |
372 | if (exception()) { |
373 | // Already detected an exception. |
374 | return false; |
375 | } |
376 | |
377 | return !checkForNCCLErrors(ncclComms_) && finishedGPUExecutionInternal(); |
378 | } |
379 | |
380 | void ProcessGroupNCCL::WorkNCCL::checkAndSetException() { |
381 | if (exception()) { |
382 | // We already have an exception. |
383 | return; |
384 | } |
385 | |
386 | auto exception_ptr = checkForNCCLErrors(ncclComms_); |
387 | std::unique_lock<std::mutex> lock(mutex_); |
388 | exception_ = exception_ptr; |
389 | if (exception_) { |
390 | LOG(INFO) << "[Rank " << rank_ << "]" |
391 | << " found async exception when checking for NCCL errors: " |
392 | << getExceptionMsgFromExceptionPtr(exception_); |
393 | } |
394 | } |
395 | |
396 | void ProcessGroupNCCL::WorkNCCL::setException( |
397 | std::exception_ptr exception_ptr) { |
398 | std::unique_lock<std::mutex> lock(mutex_); |
399 | exception_ = exception_ptr; |
400 | } |
401 | |
// Helper that checks whether the NCCL kernels have completed on the GPUs
403 | bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() { |
404 | checkAndSetException(); |
405 | return finishedGPUExecutionInternal(); |
406 | } |
407 | |
408 | bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const { |
409 | for (const auto i : c10::irange(devices_.size())) { |
410 | // Checking the work's corresponding CUDA events' status |
411 | if (!(*ncclStartEvents_)[i].query()) { |
412 | return false; |
413 | } |
414 | } |
415 | return true; |
416 | } |
417 | |
418 | bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { |
419 | try { |
420 | for (const auto i : c10::irange(devices_.size())) { |
421 | // Checking the work's corresponding CUDA events' status |
422 | if (!(*ncclEndEvents_)[i].query()) { |
423 | return false; |
424 | } |
425 | } |
426 | } catch (const std::exception& e) { |
427 | if (std::string(e.what()).find("driver shutting down" ) == |
428 | std::string::npos) { |
429 | throw; |
430 | } |
431 | LOG(INFO) << "[Rank " << rank_ |
432 | << "] Event query failed with exception: " << e.what(); |
433 | } |
434 | return true; |
435 | } |
436 | |
437 | void ProcessGroupNCCL::WorkNCCL::checkAndThrowException() { |
438 | // Set the appropriate exception if found. |
439 | checkAndSetException(); |
440 | |
441 | // Throw an exception, only if we have a valid exception. |
442 | if (exception()) { |
443 | std::rethrow_exception(exception()); |
444 | } |
445 | } |
446 | |
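// Called from the work cleanup thread: if an exception has been recorded on
// this work, log it and, when asyncErrorHandling is TearDown, rethrow it so
// that the whole process is brought down.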
447 | void ProcessGroupNCCL::WorkNCCL::handleNCCLGuard( |
448 | ErrorHandlingMode asyncErrorHandling) { |
449 | std::lock_guard<std::mutex> lock(mutex_); |
450 | if (exception_) { |
451 | auto exceptionMsg = c10::str( |
452 | "Some NCCL operations have failed or timed out. Due to the " , |
453 | "asynchronous nature of CUDA kernels, subsequent GPU operations " , |
454 | "might run on corrupted/incomplete data." ); |
455 | LOG(ERROR) << exceptionMsg; |
456 | C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleNCCLGuard" ); |
457 | if (asyncErrorHandling == TearDown) { |
458 | auto tearDownMsg = c10::str( |
459 | "To avoid data inconsistency, we are taking the entire process down." ); |
460 | LOG(ERROR) << tearDownMsg; |
461 | std::rethrow_exception(exception_); |
462 | } |
463 | } |
464 | } |
465 | |
466 | void ProcessGroupNCCL::WorkNCCL::synchronize() { |
467 | // Call Synchronize without a timeout. We use this method to avoid adding a |
468 | // timeout argument to the public synchronize API. |
469 | synchronizeInternal(kNoTimeout); |
470 | } |
471 | |
472 | void ProcessGroupNCCL::WorkNCCL::synchronizeStreams() { |
473 | for (const auto i : c10::irange(devices_.size())) { |
474 | auto currentStream = at::cuda::getCurrentCUDAStream(devices_[i].index()); |
475 | // Block the current stream on the NCCL stream |
476 | (*ncclEndEvents_)[i].block(currentStream); |
477 | } |
478 | } |
479 | |
480 | // Waiting on the work's corresponding CUDA events |
481 | void ProcessGroupNCCL::WorkNCCL::synchronizeInternal( |
482 | std::chrono::milliseconds timeout) { |
483 | synchronizeStreams(); |
484 | |
485 | // In case of blocking, wait for the operation to complete. |
486 | if (blockingWait_) { |
487 | // Wait for the operation to complete. |
488 | while (!isCompleted()) { |
489 | if (timedOut()) { |
// When an operation times out due to errors that are not detected by the
// NCCL communicators, ncclCommWatchdog cannot see this timeout and thus
// cannot abort the ncclComms accordingly. So we explicitly abort the
// ncclComms here before throwing the timeout exception to users; after
// this, ncclCommWatchdog can detect that the NCCL communicators were
// aborted and clean up devNCCLCommMap_ accordingly. If the timeout
// exception were thrown without aborting the NCCL communicators here, it
// was observed that the CUDA GPU stays at 100% utilization and cannot run
// new work successfully.
499 | |
500 | std::stringstream ss; |
501 | ss << *this; |
502 | auto timeoutErrorMsg = |
503 | c10::str("Work " , ss.str(), " timed out in call to wait()." ); |
504 | for (const auto& ncclComm : ncclComms_) { |
505 | ncclComm->ncclCommAbort(timeoutErrorMsg); |
506 | const auto& storeKey = getNcclAbortedCommStoreKey( |
507 | buildNcclUniqueIdStr(ncclComm->getNcclId())); |
508 | auto rankStr = std::to_string(rank_); |
509 | store_->set(storeKey, rankStr); |
510 | LOG(INFO) << "[Rank " << rank_ |
511 | << "] Wrote aborted communicator id to store: " << storeKey; |
512 | } |
513 | auto currentTimepoint = std::chrono::steady_clock::now(); |
514 | auto timeElapsed = |
515 | std::chrono::duration_cast<std::chrono::milliseconds>( |
516 | currentTimepoint - workStartTime_); |
517 | std::string exceptionMsg = c10::str( |
518 | "[Rank " , |
519 | rank_, |
520 | "] " , |
521 | "Caught collective operation timeout: " , |
522 | (*this), |
523 | " ran for " , |
524 | timeElapsed.count(), |
525 | " milliseconds before timing out." ); |
526 | TORCH_CHECK(false, exceptionMsg); |
527 | } |
528 | // Check for errors and throw appropriate exception. |
529 | checkAndThrowException(); |
530 | std::this_thread::sleep_for( |
531 | std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); |
532 | } |
533 | checkAndThrowException(); |
534 | } |
535 | |
536 | // Device synchronize only after we've completed timeout checks. |
537 | if (!barrierTensors_.empty()) { |
538 | // If we use the work to do barrier, we should block here |
539 | for (auto& device : devices_) { |
540 | at::cuda::CUDAGuard gpuGuard(device); |
541 | AT_CUDA_CHECK(cudaDeviceSynchronize()); |
542 | } |
543 | } |
544 | } |
545 | |
546 | // Same as calling synchronize(). |
547 | bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { |
548 | RECORD_PARAM_COMMS( |
549 | static_cast<int>(this->seq_), // seq |
550 | 0, // process group ptr |
551 | rank_, // rank |
552 | "wait" , // colName |
553 | 0, // inSize |
554 | 0, // outSize |
555 | at::kByte, // dType |
556 | std::vector<int64_t>(), // inSplitSizes |
557 | std::vector<int64_t>()); // outSplitSizes |
558 | synchronizeInternal(timeout); |
559 | // Always return true, because abort API is not implemented. |
560 | return true; |
561 | } |
562 | |
563 | void ProcessGroupNCCL::WorkNCCL::abort() { |
564 | TORCH_CHECK(false, "ProcessGroupNCCL::WorkNCCL::abort not implemented." ); |
565 | } |
566 | |
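// Returns true once the elapsed time since this work was created exceeds the
// operation timeout (opTimeout_).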
567 | bool ProcessGroupNCCL::WorkNCCL::timedOut() { |
568 | auto currentTimepoint = std::chrono::steady_clock::now(); |
569 | return ( |
570 | std::chrono::duration_cast<std::chrono::milliseconds>( |
571 | currentTimepoint - workStartTime_) >= opTimeout_); |
572 | } |
573 | |
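// CoalescedWorkNCCL bundles the WorkNCCL objects produced inside a coalescing
// group so that wait() can be applied to all of them together.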
574 | ProcessGroupNCCL::CoalescedWorkNCCL::CoalescedWorkNCCL( |
575 | std::vector<ProcessGroupNCCL::WorkNCCL> works, |
576 | int rank, |
577 | OpType opType) |
578 | : Work(rank, opType, nullptr), works_(std::move(works)) {} |
579 | |
580 | ProcessGroupNCCL::CoalescedWorkNCCL::~CoalescedWorkNCCL() = default; |
581 | |
582 | c10::intrusive_ptr<ProcessGroupNCCL::CoalescedWorkNCCL> ProcessGroupNCCL:: |
583 | initCoalescedWork( |
584 | const std::vector<c10::intrusive_ptr<Work>>& works, |
585 | int rank, |
586 | OpType opType) { |
587 | std::vector<ProcessGroupNCCL::WorkNCCL> ncclWorks; |
588 | ncclWorks.reserve(works.size()); |
589 | for (auto& work : works) { |
590 | ncclWorks.push_back(*static_cast<ProcessGroupNCCL::WorkNCCL*>(work.get())); |
591 | } |
592 | return c10::make_intrusive<ProcessGroupNCCL::CoalescedWorkNCCL>( |
593 | ncclWorks, rank, opType); |
594 | } |
595 | |
596 | // Same as calling synchronize(). |
597 | bool ProcessGroupNCCL::CoalescedWorkNCCL::wait( |
598 | std::chrono::milliseconds timeout) { |
599 | for (auto& w : works_) { |
600 | w.wait(timeout); |
601 | } |
602 | // Always return true, because abort API is not implemented. |
603 | return true; |
604 | } |
605 | |
606 | ProcessGroupNCCL::ProcessGroupNCCL( |
607 | const c10::intrusive_ptr<Store>& store, |
608 | int rank, |
609 | int size, |
610 | c10::intrusive_ptr<Options> options) |
611 | : Backend(rank, size), |
612 | store_(store), |
613 | options_(options), |
614 | ncclCommCounter_(0), |
615 | traceKeyStart_(getTraceStartKey("NCCL" , rank)), |
616 | traceKeyEnd_(getTraceEndKey("NCCL" , rank)), |
617 | terminateProcessGroup_(false) { |
618 | TORCH_CHECK( |
619 | at::cuda::getNumGPUs() != 0, |
620 | "ProcessGroupNCCL is only supported with GPUs, no GPUs found!" ); |
621 | blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT); |
622 | asyncErrorHandling_ = static_cast<ErrorHandlingMode>( |
623 | parseEnvVarIntDefault(NCCL_ASYNC_ERROR_HANDLING, 0)); |
624 | desyncDebug_ = parseEnvVarFlag(NCCL_DESYNC_DEBUG) || |
625 | (dist_debug_level_ >= DebugLevel::Detail); |
626 | |
627 | if (blockingWait_) { |
628 | if (asyncErrorHandling_ != NoHandling || desyncDebug_) { |
629 | LOG(INFO) << "[Rank " << rank_ << "] NCCL_BLOCKING_WAIT and " |
630 | << "NCCL_ASYNC_ERROR_HANDLING|NCCL_DESYNC_DEBUG" |
631 | << "should not both be enabled. " |
632 | << "Only NCCL_BLOCKING_WAIT is being used in this process." ; |
633 | asyncErrorHandling_ = NoHandling; |
634 | desyncDebug_ = false; |
635 | } |
636 | } else { |
637 | if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { |
638 | LOG(INFO) << "[Rank " << rank_ |
639 | << "] NCCL_DESYNC_DEBUG and NCCL_ASYNC_ERROR_HANDLING " |
640 | << "must both be enabled. " |
641 | << "Enabling NCCL_ASYNC_ERROR_HANDLING." ; |
642 | asyncErrorHandling_ = TearDown; |
643 | } |
644 | } |
645 | |
646 | if (parseEnvVarFlag(ENABLE_NCCL_HEALTH_CHECK)) { |
647 | // Perform health check by initializing dummy communicators and destroying |
648 | // them. This will help indicate any NCCL-related issues prior to the first |
649 | // collective. |
// Run it in a separate thread and wait on the CV to handle timeouts, since
// the majority of getNCCLComm failures are hangs.
652 | runHealthCheck(); |
653 | } |
654 | |
655 | #ifdef ENABLE_NCCL_ERROR_CHECKING |
656 | ncclCommWatchdogThread_ = |
657 | std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); |
658 | #endif |
659 | |
660 | if (asyncErrorHandling_ != NoHandling) { |
661 | workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this); |
662 | } |
663 | |
664 | init(); |
665 | LOG(INFO) << "[Rank " << rank_ |
666 | << "] ProcessGroupNCCL initialized with following options:" |
667 | << "\nNCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_ |
668 | << "\nNCCL_DESYNC_DEBUG: " << desyncDebug_ |
669 | << "\nNCCL_BLOCKING_WAIT: " << blockingWait_ |
670 | << "\nTIMEOUT(ms): " << options_->timeout.count() |
671 | << "\nUSE_HIGH_PRIORITY_STREAM: " |
672 | << options_->is_high_priority_stream; |
673 | |
674 | RECORD_PARAM_COMMS( |
675 | 0, // seq |
676 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
677 | rank, // rank |
678 | "init" , // colName |
679 | 0, // inSize |
680 | 0, // outSize |
681 | at::kByte, // dType |
682 | std::vector<int64_t>(), // inSplitSizes |
683 | std::vector<int64_t>()); // outSplitSizes |
684 | |
685 | #ifdef USE_NCCL_WITH_UCC |
686 | static c10::once_flag initialize_ucc_lib_flag; |
687 | c10::call_once(initialize_ucc_lib_flag, [&] { |
688 | uccLib_ = loadTorchUCC(); |
689 | if (uccLib_ != nullptr) { |
690 | LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded" ; |
691 | } |
692 | }); |
693 | |
694 | if (uccLib_ != nullptr) { |
695 | LOG(INFO) << "[Rank " << rank_ << "] torch_ucc.so loaded" ; |
696 | typedef c10::intrusive_ptr<Backend> fn( |
697 | const c10::intrusive_ptr<Store>& store, int rank, int size); |
698 | auto createProcessGroupUCC = |
699 | reinterpret_cast<fn*>(uccLib_->sym("createProcessGroupUCC" )); |
700 | if (createProcessGroupUCC != nullptr) { |
701 | uccPG_ = createProcessGroupUCC(store, rank_, size_); |
702 | LOG(INFO) << "[Rank " << rank_ << "] ProcessGroupUCC created." ; |
703 | } |
704 | } |
705 | #endif |
706 | } |
707 | |
708 | void ProcessGroupNCCL::runHealthCheck() { |
// Run the health check in a separate thread and wait on the CV to handle
// timeouts, since the majority of getNCCLComm failures are hangs.
711 | |
712 | struct HealthCheckData { |
713 | std::mutex healthCheckMutex; |
714 | std::condition_variable healthCheckCv; |
715 | bool healthCheckSuccess = false; |
716 | std::exception_ptr healthCheckException; |
717 | }; |
718 | |
719 | HealthCheckData healthCheckData; |
720 | auto t = std::thread([&healthCheckData, this]() { |
721 | try { |
722 | std::vector<at::Device> rankDevice = {getDeviceForRank(rank_)}; |
723 | const auto key = getKeyFromDevices(rankDevice); |
// The OpType does not matter; we only need one that does not go through
// the send/recv path.
726 | getNCCLComm(key, rankDevice, OpType::ALLREDUCE); |
727 | // Now destroy the communicators and remove them from cache so we don't |
728 | // use destroyed communicators. |
729 | destroyNCCLComms(key); |
730 | // Notify main thread the health check is complete. |
731 | { |
732 | std::lock_guard<std::mutex> lk(healthCheckData.healthCheckMutex); |
733 | healthCheckData.healthCheckSuccess = true; |
734 | } |
735 | healthCheckData.healthCheckCv.notify_one(); |
736 | } catch (const std::exception& e) { |
737 | // Populate exception ptr. |
738 | healthCheckData.healthCheckException = std::current_exception(); |
739 | // Unblock waiting main thread which will report exception. |
740 | healthCheckData.healthCheckCv.notify_one(); |
741 | } // Unknown exceptions will just cause the program to terminate. |
742 | }); |
743 | // We don't need to join the thread, just need to verify health check via the |
744 | // CV. Hence we detach the thread here. |
745 | t.detach(); // NOLINT |
746 | LOG(INFO) << "[Rank " << rank_ << "]" |
747 | << " will wait up to " << options_->timeout.count() |
748 | << " msec for NCCL health check to complete." ; |
749 | std::unique_lock<std::mutex> lock(healthCheckData.healthCheckMutex); |
750 | healthCheckData.healthCheckCv.wait_for( |
751 | lock, options_->timeout, [&healthCheckData]() { |
752 | return healthCheckData.healthCheckSuccess; |
753 | }); |
754 | |
755 | if (healthCheckData.healthCheckException) { |
756 | std::rethrow_exception(healthCheckData.healthCheckException); |
757 | } |
758 | // If there is no exception, the likely culprit is a timeout/hang which is how |
759 | // most communicator init issues manifest themselves. |
760 | TORCH_CHECK( |
761 | healthCheckData.healthCheckSuccess, |
762 | "ProcessGroupNCCL: Health check failure: Failed to initialize NCCL communicator on rank " , |
763 | rank_); |
764 | } |
765 | |
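// The NCCL backend bumps seq_ itself for every collective (see collective()),
// so there is nothing to do when setting the sequence number for the group.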
766 | void ProcessGroupNCCL::setSequenceNumberForGroup() {} |
767 | |
768 | uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() { |
769 | return seq_; |
770 | } |
771 | |
772 | ProcessGroupNCCL::~ProcessGroupNCCL() { |
773 | terminateProcessGroup_.store(true); |
774 | |
775 | watchdogCV_.notify_one(); |
776 | #ifdef ENABLE_NCCL_ERROR_CHECKING |
777 | ncclCommWatchdogThread_.join(); |
778 | #endif |
779 | |
780 | if (asyncErrorHandling_ != NoHandling) { |
781 | workMetaListCV_.notify_one(); |
782 | workCleanupThread_.join(); |
783 | } |
784 | |
785 | { |
786 | // Abort all NCCL Communicators on Process Group Destruction |
787 | std::lock_guard<std::mutex> lock(mutex_); |
788 | for (auto& it : devNCCLCommMap_) { |
789 | auto& ncclComms = it.second; |
790 | |
791 | for (const auto& ncclComm : ncclComms) { |
792 | std::string abortReason = |
793 | c10::str("Process Group destroyed on rank " , rank_); |
794 | ncclComm->ncclCommAbort(abortReason); |
795 | } |
796 | } |
797 | } |
798 | } |
799 | |
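// Scan the outstanding work objects for operations that exceeded their
// timeout, set a timeout exception on them, and abort their NCCL
// communicators, collecting the aborted communicator ids for the caller.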
800 | void ProcessGroupNCCL::abortTimedOutCollectives( |
801 | std::unordered_set<std::string>& abortedCommIds) { |
802 | std::unique_lock<std::mutex> lock(workMetaListMutex_); |
803 | for (auto& work : workMetaList_) { |
804 | work.checkAndSetException(); |
// Aborting NCCL communicators due to NCCL errors is already handled by the
// watchdog before this call, so skip any work that already has an exception
// set.
806 | if (work.exception()) { |
807 | continue; |
808 | } |
809 | |
810 | // Check for Timeouts in the WorkNCCL Operations, and abort all |
811 | // communicators accordingly. |
812 | if (work.timedOut()) { |
813 | auto currentTimepoint = std::chrono::steady_clock::now(); |
814 | auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>( |
815 | currentTimepoint - work.workStartTime_); |
816 | std::string exceptionMsg = c10::str( |
817 | "[Rank " , |
818 | rank_, |
819 | "] " , |
820 | "Watchdog caught collective operation timeout: " , |
821 | work, |
822 | " ran for " , |
823 | timeElapsed.count(), |
824 | " milliseconds before timing out." ); |
825 | if (desyncDebug_) { |
826 | exceptionMsg += retrieveDesyncReport(store_, "NCCL" , rank_, size_); |
827 | } |
828 | LOG(ERROR) << exceptionMsg; |
829 | std::exception_ptr exception_ptr = |
830 | std::make_exception_ptr(std::runtime_error(exceptionMsg)); |
831 | work.setException(exception_ptr); |
832 | for (const auto& ncclComm : work.ncclComms_) { |
833 | ncclComm->ncclCommAbort(exceptionMsg); |
834 | abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId())); |
835 | } |
836 | } |
837 | } |
838 | } |
839 | |
840 | void ProcessGroupNCCL::ncclCommWatchdog() { |
841 | try { |
842 | LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread started!" ; |
843 | ncclCommWatchdogInternal(); |
844 | LOG(INFO) << "[Rank " << rank_ |
845 | << "] NCCL watchdog thread terminated normally" ; |
846 | } catch (std::exception& e) { |
847 | LOG(INFO) << "[Rank " << rank_ |
848 | << "] NCCL watchdog thread terminated with exception: " |
849 | << e.what(); |
850 | } catch (...) { |
851 | LOG(INFO) << "[Rank " << rank_ |
852 | << "] NCCL watchdog thread terminated with unknown exception" ; |
853 | } |
854 | } |
855 | |
856 | void ProcessGroupNCCL::ncclCommWatchdogInternal() { |
857 | while (!terminateProcessGroup_.load()) { |
858 | std::unordered_set<std::string> abortedCommIds; |
859 | std::unordered_set<std::string> allCommIds; |
860 | |
861 | { |
862 | // Loop through the cache of communicators for NCCL errors. |
863 | std::lock_guard<std::mutex> lock(mutex_); |
864 | for (auto& it : devNCCLCommMap_) { |
865 | auto& ncclComms = it.second; |
866 | |
867 | for (const auto& ncclComm : ncclComms) { |
868 | allCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId())); |
869 | } |
870 | std::exception_ptr ncclErrorException = checkForNCCLErrors(ncclComms); |
871 | if (ncclErrorException) { |
872 | auto exceptionMsg = |
873 | getExceptionMsgFromExceptionPtr(ncclErrorException); |
874 | LOG(INFO) |
875 | << "[Rank " << rank_ |
876 | << "] Received NCCL errors for communicators in the cache: \n" |
877 | << "NCCL error: \n" |
878 | << exceptionMsg; |
879 | |
880 | if (blockingWait_ || asyncErrorHandling_ != NoHandling) { |
881 | LOG(INFO) << "[Rank " << rank_ |
882 | << "] Aborting communicators that received errors" ; |
883 | // We abort NCCL communicators that have received errors from this |
884 | // thread, and exceptions are set on the corresponding work objects. |
885 | // The workCleanupThread will then loop through the unfinished |
886 | // collectives and throw exceptions if an exception has been set on |
887 | // any of the work objects from this thread. |
888 | for (const auto& ncclComm : ncclComms) { |
889 | // We are aborting remaining communicators due to an error in |
890 | // at least one of these communicators, so propagate that reason |
891 | // for better debugability. |
892 | ncclComm->ncclCommAbort(exceptionMsg); |
// Note that we don't remove the aborted communicators from the
// cache. The reason is that if we do remove the communicator
// from the cache, it is possible that a new collective operation
// calls `ncclCommInitRank` to create a new communicator whereas
// other ranks might have failed/timed out and didn't enter
// `ncclCommInitRank`. As a result, when there is a failure on
// a communicator, the application receives an exception and it is
// their responsibility to destroy the process group and recreate
// it to recover from errors.
902 | abortedCommIds.emplace( |
903 | buildNcclUniqueIdStr(ncclComm->getNcclId())); |
904 | } |
905 | } |
906 | } |
907 | } |
908 | } |
909 | |
910 | if (asyncErrorHandling_ != NoHandling) { |
911 | abortTimedOutCollectives(abortedCommIds); |
912 | } |
913 | |
914 | if (blockingWait_) { |
// When we abort a communicator on one rank, it is likely to cause
916 | // other ranks to hang indefinitely. As a result, whenever we abort a |
917 | // communicator, we write its ID to the store. The watchdog on other ranks |
918 | // then monitor the store, find an aborted communicator ID and abort their |
919 | // respective communicator as well. |
920 | |
921 | // Record the aborted communicators locally and in the store. |
922 | for (const auto& abortedCommId : abortedCommIds) { |
923 | abortedComms_.emplace(abortedCommId); |
924 | const auto& storeKey = getNcclAbortedCommStoreKey(abortedCommId); |
925 | auto rankStr = std::to_string(rank_); |
926 | store_->set(storeKey, rankStr); |
927 | LOG(INFO) << "[Rank " << rank_ |
928 | << "] Watchdog wrote aborted communicator id to store: " |
929 | << storeKey; |
930 | } |
931 | |
932 | // Check for any communicators in the store and abort them if needed. |
933 | for (const auto& commId : allCommIds) { |
934 | if (abortedComms_.find(commId) == abortedComms_.end()) { |
// Abort the communicator if another rank has reported it as aborted and we
// haven't aborted it yet (don't wait longer than the watchdog sleep time).
937 | const auto& storeKey = getNcclAbortedCommStoreKey(commId); |
938 | try { |
939 | store_->wait( |
940 | {storeKey}, |
941 | std::chrono::milliseconds(kWaitForAbortCommStoreKey)); |
942 | auto val = store_->get(storeKey); |
943 | std::string rank(reinterpret_cast<char*>(val.data()), val.size()); |
944 | std::stringstream ss; |
945 | ss << "[Rank " << rank_ << "] Found key in store: " << storeKey |
946 | << ", from rank: " << rank |
947 | << ". This means that rank has aborted its NCCL communicators previously and is not in a healthy state." |
948 | << ". Aborting appropriate communicators" ; |
949 | std::string abortReason = ss.str(); |
950 | LOG(WARNING) << abortReason; |
951 | |
952 | // Now abort the appropriate communicators. |
953 | std::lock_guard<std::mutex> lock(mutex_); |
954 | auto it = ncclIdToCommMap_.find(commId); |
955 | TORCH_INTERNAL_ASSERT(it != ncclIdToCommMap_.end()); |
956 | for (const auto& ncclComm : it->second) { |
957 | // The reason we are aborting is because some other ranks have |
958 | // aborted their communicators originally, so propagate that |
959 | // reason. |
960 | ncclComm->ncclCommAbort(abortReason); |
961 | } |
962 | abortedComms_.emplace(commId); |
963 | LOG(INFO) << "[Rank " << rank_ |
964 | << "] Aborted communicators for key in store: " |
965 | << storeKey; |
966 | } catch (std::exception& e) { |
967 | VLOG(1) << "Did not find key in store: " << storeKey |
968 | << ", error: " << e.what(); |
969 | } |
970 | } |
971 | } |
972 | } |
973 | |
974 | std::unique_lock<std::mutex> lock(watchdogCVMutex_); |
975 | watchdogCV_.wait_for( |
976 | lock, |
977 | std::chrono::milliseconds(kWatchdogThreadSleepMillis), |
978 | [&]() -> bool { return terminateProcessGroup_.load(); }); |
979 | } |
980 | } |
981 | |
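// Loop run by workCleanupThread_: periodically scans workMetaList_, updates
// the desync-debug trace keys in the store, surfaces recorded exceptions via
// handleNCCLGuard, and removes completed work objects from the list.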
982 | void ProcessGroupNCCL::workCleanupLoop() { |
983 | bool done = false; |
984 | while (!terminateProcessGroup_.load() || !done) { |
985 | std::list<WorkNCCL> doneWorks; |
986 | { |
987 | std::unique_lock<std::mutex> lock(workMetaListMutex_); |
// We busy-poll the work vector every kWorkCleanupThreadSleepMillis
// milliseconds, waking up early if terminateProcessGroup_ is set.
990 | workMetaListCV_.wait_for( |
991 | lock, |
992 | std::chrono::milliseconds(kWorkCleanupThreadSleepMillis), |
993 | [&]() -> bool { return terminateProcessGroup_.load(); }); |
994 | |
995 | for (auto it = workMetaList_.begin(); it != workMetaList_.end(); |
996 | /* no increment*/) { |
997 | auto& work = *it; |
998 | |
999 | if (desyncDebug_ && !work.exception()) { |
1000 | if (!work.startTraceUpdated_ && work.isStarted() && |
1001 | !terminateProcessGroup_.load() && !storeError_) { |
1002 | work.startTraceUpdated_ = true; |
1003 | storeError_ = !c10d::traceUpdate( |
1004 | store_, |
1005 | traceKeyStart_, |
1006 | work.seq_, |
1007 | opTypeToString(work.opType_)); |
1008 | } |
1009 | } |
1010 | |
1011 | if (work.isCompleted()) { |
1012 | if (desyncDebug_ && !work.exception()) { |
// Close the window between the work.isStarted() check above and the
// work.isCompleted() check here: record the start trace now if it was
// not recorded earlier.
1015 | if (!work.startTraceUpdated_ && !terminateProcessGroup_.load() && |
1016 | !storeError_) { |
1017 | storeError_ = !c10d::traceUpdate( |
1018 | store_, |
1019 | traceKeyStart_, |
1020 | work.seq_, |
1021 | opTypeToString(work.opType_)); |
1022 | } |
1023 | if (!terminateProcessGroup_.load() && !storeError_) { |
1024 | storeError_ = !c10d::traceUpdate( |
1025 | store_, |
1026 | traceKeyEnd_, |
1027 | work.seq_, |
1028 | opTypeToString(work.opType_)); |
1029 | } |
1030 | } |
1031 | // Handle Exceptions on failed GPU operations and remove completed |
1032 | // workNCCL objects from work vector. |
1033 | if (!terminateProcessGroup_.load()) { |
1034 | work.handleNCCLGuard(asyncErrorHandling_); |
1035 | } |
1036 | doneWorks.push_back(std::move(*it)); |
1037 | it = workMetaList_.erase(it); |
1038 | } else { |
1039 | // Increment the iterator if the current WorkNCCL object is not |
1040 | // completed. |
1041 | ++it; |
1042 | } |
1043 | } |
1044 | done = workMetaList_.empty(); |
1045 | } |
1046 | doneWorks.clear(); |
1047 | } |
1048 | } |
1049 | |
1050 | std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors( |
1051 | const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const { |
1052 | return checkForNCCLErrorsInternal(ncclComms); |
1053 | } |
1054 | |
1055 | std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors( |
1056 | const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) { |
1057 | return checkForNCCLErrorsInternal(ncclComms); |
1058 | } |
1059 | |
1060 | std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( |
1061 | const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) { |
1062 | for (const auto& ncclComm : ncclComms) { |
1063 | // Prioritize commFailureReason over checkForNcclError() result if |
1064 | // commFailureReason is set. |
1065 | auto commFailureReason = ncclComm->getNcclCommFailureReason(); |
1066 | if (commFailureReason != c10::nullopt) { |
1067 | return std::make_exception_ptr(std::runtime_error(c10::str( |
1068 | "NCCL communicator encountered error set by ProcessGroupNCCL: " , |
1069 | *commFailureReason))); |
1070 | } |
1071 | ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError(); |
1072 | if (ncclAsyncErr != ncclSuccess) { |
1073 | return std::make_exception_ptr(std::runtime_error( |
1074 | "NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" + |
1075 | getNcclErrorDetailStr(ncclAsyncErr))); |
1076 | } |
1077 | } |
1078 | |
1079 | return nullptr; |
1080 | } |
1081 | |
1082 | void ProcessGroupNCCL::broadcastUniqueNCCLID( |
1083 | ncclUniqueId* ncclID, |
1084 | bool isSingleP2POp, |
1085 | const std::string& p2pKey, |
1086 | int p2pRank) { |
1087 | // For collective operations: |
1088 | // For every NCCL communicator that we create we need to broadcast |
1089 | // a unique ID from rank 0 to all other ranks. This broadcast is |
1090 | // done by rank 0 setting a key in the store and all other ranks |
1091 | // retrieving the contents of that key. A single process group |
1092 | // may create multiple NCCL communicators, so we use a sequence |
1093 | // number to differentiate between them. |
1094 | // For single point-to-point operations: |
1095 | // The sequence number will only be increased on 2 out of all the |
1096 | // processes in a Process Group. So all following collective |
1097 | // operations will see different sequence numbers which will cause |
1098 | // runtime errors. To avoid that, use the src:target pair instead |
1099 | // of sequence number for p2p communications. |
1100 | |
1101 | std::string storeKey; |
1102 | if (!isSingleP2POp) { |
1103 | storeKey = std::to_string(ncclCommCounter_++); |
1104 | } else { |
1105 | storeKey = p2pKey; |
1106 | } |
1107 | if (rank_ == 0 || (isSingleP2POp && p2pRank == 0)) { |
1108 | auto vec = std::vector<uint8_t>( |
1109 | reinterpret_cast<uint8_t*>(ncclID), |
1110 | reinterpret_cast<uint8_t*>(ncclID) + NCCL_UNIQUE_ID_BYTES); |
1111 | store_->set(storeKey, vec); |
1112 | } else { |
1113 | try { |
1114 | auto vec = store_->get(storeKey); |
1115 | TORCH_CHECK(vec.size() == NCCL_UNIQUE_ID_BYTES); |
1116 | std::memcpy(ncclID, vec.data(), vec.size()); |
1117 | } catch (const std::exception& e) { |
1118 | std::string exceptionMsg = c10::str( |
1119 | "[" , |
1120 | rank_, |
1121 | "] is setting up NCCL communicator and " |
1122 | "retrieving ncclUniqueId from [0] via c10d key-value store by key '" , |
1123 | storeKey, |
1124 | "', but store->get('" , |
1125 | storeKey, |
1126 | "') got error: " ); |
1127 | TORCH_CHECK(false, exceptionMsg + e.what()); |
1128 | } catch (...) { |
1129 | TORCH_CHECK( |
1130 | false, |
1131 | c10::str( |
1132 | "Unknown exception while [" , |
1133 | rank_, |
1134 | "] is setting up NCCL communicator and " |
1135 | "retrieving ncclUniqueId from [0] via c10d key-value store by key '" , |
1136 | storeKey, |
1137 | "'" )); |
1138 | } |
1139 | } |
1140 | } |
1141 | |
1142 | void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { |
1143 | std::lock_guard<std::mutex> lock(mutex_); |
1144 | if (devNCCLCommMap_.find(devNCCLCommMapKey) == devNCCLCommMap_.end()) { |
1145 | TORCH_INTERNAL_ASSERT( |
1146 | false, |
1147 | "Expected to find key " , |
1148 | devNCCLCommMapKey, |
1149 | " in NCCL communicator map." ); |
1150 | } |
1151 | std::vector<std::shared_ptr<NCCLComm>>& ncclComms = |
1152 | devNCCLCommMap_[devNCCLCommMapKey]; |
1153 | // Loop through communicators and call ncclCommAbort. |
1154 | for (const auto& comm : ncclComms) { |
// ncclCommDestroy(comm->getNcclComm()) results in a segfault when the PG is
// being destroyed, so we use ncclCommAbort here.
1157 | comm->ncclCommAbort(); |
1158 | } |
1159 | // Remove communicators from the cache. |
1160 | devNCCLCommMap_.erase(devNCCLCommMapKey); |
1161 | // Clear used device indices. |
1162 | usedDeviceIdxs_.clear(); |
1163 | } |
1164 | |
1165 | std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm( |
1166 | const std::string& devicesKey, |
1167 | const std::vector<at::Device>& devices, |
1168 | OpType opType, |
1169 | int p2pRank, |
1170 | bool isSendRecvSelf) { |
1171 | // Sanity check |
1172 | if (devicesKey.empty()) { |
1173 | TORCH_CHECK( |
1174 | false, |
1175 | "Not able to create/get the NCCL Communicator since " |
1176 | "the GPU devices are not known" ); |
1177 | } |
1178 | |
1179 | for (auto& device : devices) { |
1180 | usedDeviceIdxs_.insert(device.index()); |
1181 | } |
1182 | |
1183 | { |
1184 | std::lock_guard<std::mutex> lock(mutex_); |
1185 | if (devNCCLCommMap_.find(devicesKey) != devNCCLCommMap_.end()) { |
1186 | // Reuse the cached communicator if there is one. |
1187 | return devNCCLCommMap_[devicesKey]; |
1188 | } |
1189 | } |
1190 | |
1191 | // NCCL communicator not cached, create a new entry |
1192 | std::vector<std::shared_ptr<NCCLComm>> ncclComms; |
1193 | ncclComms.resize(devices.size()); |
1194 | |
1195 | // Create the unique NCCL ID and broadcast it |
1196 | ncclUniqueId ncclID; |
1197 | |
// For batch_isend_irecv, ncclGroupStart() has already been called upfront.
1199 | bool batchP2P = ncclActiveGroupCounter_ > 0; |
1200 | bool singleP2POp = isP2POp(opType, batchP2P); |
// For point-to-point communication, the lower rank of the two gets the unique id.
1202 | if (rank_ == 0 || (singleP2POp && p2pRank == 0)) { |
1203 | C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID), c10::nullopt); |
1204 | } |
1205 | |
// For point-to-point communication within the same process, no broadcast is needed.
1207 | if (!isSendRecvSelf) { |
1208 | // Broadcast so that each process can have a unique NCCL ID |
1209 | broadcastUniqueNCCLID(&ncclID, singleP2POp, devicesKey, p2pRank); |
1210 | } |
1211 | |
1212 | at::cuda::OptionalCUDAGuard gpuGuard; |
1213 | |
1214 | std::vector<at::cuda::CUDAStream> streamVal; |
1215 | streamVal.reserve(devices.size()); |
1216 | |
1217 | // [Group Start/End Note] This is used to ensure that nccl communicator will |
1218 | // be created before communication primitives are called. Let's look at this |
1219 | // example: Using the batch_isend_irecv to send a tensor to a target process. |
1220 | // On the sender side, the corresponding underlying NCCL calls will look like |
1221 | // ncclGroupStart() // This is in batch_isend_irecv |
1222 | // ncclGroupStart() // This is [Note 1] |
1223 | // ncclCommInitRank() // Inside NCCLComm::create |
1224 | // ncclSend() |
1225 | // ncclGroupEnd() // This is [Note 2] |
1226 | // ncclGroupEnd() // This is in batch_isend_irecv |
1227 | // With this pattern, the nccl communicator will be created in the last |
1228 | // ncclGroupEnd which means when ncclSend is processed, the passed |
1229 | // communicator argument is NULL which will lead to runtime error. So we need |
1230 | // to "close" all active nccl groups to ensure nccl communicator is actually |
1231 | // created before encountering any communication calls. This is why we need |
1232 | // the following for loop. |
1233 | for (const auto i : c10::irange(ncclActiveGroupCounter_)) { |
1234 | (void)i; |
1235 | C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); |
1236 | } |
1237 | |
1238 | // [Note 1] Create the NCCL communicators for each GPU |
1239 | C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt); |
1240 | |
1241 | for (const auto i : c10::irange(devices.size())) { |
1242 | // GPU world size and GPU rank |
1243 | int numRanks, rank; |
1244 | |
1245 | if (!singleP2POp) { |
1246 | // Collective, all-to-all, or batch P2P |
1247 | numRanks = getSize() * devices.size(); |
1248 | rank = getRank() * devices.size() + i; |
1249 | } else if (isSendRecvSelf) { |
1250 | // Same process send and recv. |
1251 | numRanks = 1; |
1252 | rank = 0; |
1253 | } else { |
1254 | // For single point-to-point operation, there are only 2 processes |
1255 | // involved so the GPU rank is either 0 or 1. |
1256 | numRanks = 2; |
1257 | rank = p2pRank; |
1258 | } |
1259 | // Get the device index |
1260 | int deviceIndex = devices[i].index(); |
1261 | |
1262 | gpuGuard.set_index(deviceIndex); |
1263 | ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID); |
1264 | |
1265 | // Creates the NCCL streams |
1266 | streamVal.push_back( |
1267 | at::cuda::getStreamFromPool(options_->is_high_priority_stream)); |
1268 | } |
1269 | |
// [Note 2]
1271 | C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); |
1272 | |
// At this point NCCL should have been initialized, hence we can accurately
// get the env value even if NCCL sets it by reading from the nccl.conf file.
1275 | if (getRank() == 0) { |
1276 | LOG(INFO) << "NCCL_DEBUG: " << parse_env("NCCL_DEBUG" ); |
1277 | } |
1278 | |
1279 | // See [Group Start/End Note] |
1280 | for (const auto i : c10::irange(ncclActiveGroupCounter_)) { |
1281 | (void)i; |
1282 | C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt); |
1283 | } |
1284 | |
1285 | ncclStreams_.emplace(devicesKey, std::move(streamVal)); |
1286 | |
// Note: these events are created with the (default) cudaEventDisableTiming
// flag. This flag provides the best performance when used with
// cudaStreamWaitEvent() and cudaEventQuery(). Since we don't measure
// performance with CUDA events here, the flag should be set.
1291 | ncclEvents_.emplace( |
1292 | std::piecewise_construct, |
1293 | std::make_tuple(devicesKey), |
1294 | std::make_tuple(devices.size())); |
1295 | |
1296 | // Hold the lock before modifying the cache. |
1297 | std::lock_guard<std::mutex> lock(mutex_); |
1298 | |
1299 | // Record the communicators based on ncclUniqueId. |
1300 | ncclIdToCommMap_.emplace(buildNcclUniqueIdStr(ncclID), ncclComms); |
1301 | |
1302 | // Move the NCCL resource to cache |
1303 | devNCCLCommMap_.emplace(devicesKey, std::move(ncclComms)); |
1304 | return devNCCLCommMap_[devicesKey]; |
1305 | } |
1306 | |
1307 | namespace { |
1308 | |
// Check that a tensor is a dense, contiguous CUDA tensor
1310 | void check_gpu_single_tensor(const at::Tensor& tensor) { |
1311 | if (!tensor.is_cuda() || tensor.is_sparse()) { |
1312 | TORCH_CHECK(false, "Tensors must be CUDA and dense" ); |
1313 | } |
1314 | if (!tensor.is_contiguous(tensor.suggest_memory_format())) { |
1315 | TORCH_CHECK(false, "Tensors must be contiguous" ); |
1316 | } |
1317 | } |
1318 | |
1319 | // Checks that all `tensors' have the same type and shape and reside on distinct |
1320 | // GPUs. |
// TODO: test_c10d_nccl.py should consider adding tests for the error conditions
// here, i.e., tests that deliberately pass invalid tensors and check that the
// right exception is thrown.
1324 | void check_gpu_tensors_different_devices( |
1325 | const std::vector<at::Tensor>& tensors) { |
1326 | if (tensors.size() == 0) { |
1327 | TORCH_CHECK(false, "Tensor list must be nonempty" ); |
1328 | } |
1329 | if (tensors.size() > static_cast<size_t>(at::cuda::getNumGPUs())) { |
1330 | TORCH_CHECK( |
1331 | false, |
1332 | "Tensor list mustn't be larger than the number of available GPUs" ); |
1333 | } |
1334 | |
1335 | const auto& first = tensors.front(); |
1336 | |
1337 | // Set for ensuring that tensors are on separate devices. |
1338 | std::unordered_set<decltype(first.get_device())> usedDevices; |
1339 | usedDevices.reserve(tensors.size()); |
1340 | |
1341 | for (const auto& t : tensors) { |
1342 | if (!t.is_cuda() || t.is_sparse()) { |
1343 | TORCH_CHECK(false, "Tensors must be CUDA and dense" ); |
1344 | } |
1345 | if (t.scalar_type() != first.scalar_type()) { |
1346 | TORCH_CHECK(false, "Tensors must have identical type" ); |
1347 | } |
1348 | if (t.sizes() != first.sizes()) { |
1349 | TORCH_CHECK(false, "Tensors must have identical size" ); |
1350 | } |
1351 | if (t.strides() != first.strides()) { |
1352 | TORCH_CHECK(false, "Tensors must have identical strides" ); |
1353 | } |
1354 | if (!t.is_contiguous(t.suggest_memory_format())) { |
1355 | TORCH_CHECK(false, "Tensors must be contiguous" ); |
1356 | } |
1357 | const auto inserted = usedDevices.insert(t.get_device()).second; |
1358 | if (!inserted) { |
1359 | TORCH_CHECK(false, "Tensors must be on distinct GPU devices" ); |
1360 | } |
1361 | } |
1362 | } |
1363 | |
1364 | // Checks that all `tensors' have the same type and shape and reside on the same |
1365 | // GPU. |
// TODO: test_c10d_nccl.py should consider adding tests for the error conditions
// here, i.e., tests that deliberately pass invalid tensors and check that the
// right exception is thrown. The "Expected list of tensors on the same device"
// condition may be a challenge because the test would need to pass tensors on
// different devices in the same process.
1371 | int64_t check_gpu_tensors_same_device(const std::vector<at::Tensor>& tensors) { |
1372 | if (tensors.size() == 0) { |
1373 | TORCH_CHECK(false, "Tensor list must be nonempty" ); |
1374 | } |
1375 | |
1376 | const auto& first = tensors.front(); |
1377 | |
1378 | int64_t total_numel = 0; |
1379 | for (const auto& t : tensors) { |
1380 | if (!t.is_cuda() || t.is_sparse()) { |
1381 | TORCH_CHECK(false, "Tensors must be CUDA and dense" ); |
1382 | } |
1383 | if (t.scalar_type() != first.scalar_type()) { |
1384 | TORCH_CHECK(false, "Tensors must have identical type" ); |
1385 | } |
1386 | if (!t.is_non_overlapping_and_dense()) { |
1387 | TORCH_CHECK(false, "Tensors must be non-overlapping and dense" ); |
1388 | } |
1389 | // If we're in this function, the user called a _coalesced collective |
1390 | // on a set of tensors with potentially different sizes and strides. |
1391 | // Therefore, we don't check for matching sizes and strides, |
1392 | // but we do double-check tensors are on the same device. |
1393 | TORCH_CHECK( |
1394 | t.get_device() == tensors[0].get_device(), |
1395 | "Expected list of tensors on the same device" ); |
1396 | total_numel += t.numel(); |
1397 | } |
1398 | |
1399 | return total_numel; |
1400 | } |
1401 | |
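// Returns true if every tensor in the list has the same size as the first one.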
1402 | bool check_same_size(const std::vector<at::Tensor>& input_tensors) { |
1403 | for (const auto& input_tensor : input_tensors) { |
1404 | if (!input_tensors[0].is_same_size(input_tensor)) { |
1405 | return false; |
1406 | } |
1407 | } |
1408 | return true; |
1409 | } |
1410 | |
1411 | // Flatten each list in `tensor_lists' for a gather or scatter operation, and |
1412 | // ensure compatibility with the corresponding tensor in `other'. |
1413 | std::vector<at::Tensor> flatten_for_scatter_gather( |
1414 | std::vector<std::vector<at::Tensor>>& tensor_lists, |
1415 | std::vector<at::Tensor>& other, |
1416 | size_t world_size) { |
1417 | if (tensor_lists.size() != other.size()) { |
1418 | TORCH_CHECK( |
1419 | false, |
1420 | "Tensor list operands to scatter/gather must have the same length" ); |
1421 | } |
1422 | const auto num_devices = tensor_lists.size(); |
1423 | |
1424 | std::vector<at::Tensor> flattened; |
1425 | flattened.resize(num_devices); |
1426 | |
1427 | for (const auto i : c10::irange(size_t{}, num_devices)) { |
1428 | if (tensor_lists[i].size() != world_size * num_devices) { |
1429 | TORCH_CHECK( |
1430 | false, |
1431 | "Tensor list input to scatter/gather must match number of collective" |
1432 | " participants" ); |
1433 | } |
1434 | |
1435 | // Only check device match for the first tensor in the list; the call to |
1436 | // newLikeFlat() below will check the rest. |
1437 | if (tensor_lists[i].front().get_device() != other[i].get_device()) { |
1438 | TORCH_CHECK( |
1439 | false, |
1440 | "Corresponding input/output tensors to scatter/gather must all reside" |
1441 | " on the same device" ); |
1442 | } |
1443 | |
1444 | for (const auto& t : tensor_lists[i]) { |
1445 | if (t.numel() != other[i].numel()) { |
1446 | TORCH_CHECK( |
1447 | false, |
1448 | "All tensor operands to scatter/gather must have the same number of elements" ); |
1449 | } |
1450 | } |
1451 | // Flatten the tensors (from all ranks) into a single big tensor. |
1452 | flattened[i] = newLikeFlat(tensor_lists, i); |
1453 | } |
1454 | return flattened; |
1455 | } |
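// Illustration of the flattening (example values only, not a statement about
// newLikeFlat's exact implementation): for a single-device all_gather with
// world_size = 2 where every rank contributes a [3, 4] tensor,
// tensor_lists[0] holds 2 tensors of shape [3, 4] and the flattened result is
// expected to be one contiguous tensor of shape [2, 3, 4]; slot j of that
// tensor corresponds to rank j and is copied back into tensor_lists[0][j] by
// the post-processing in allgather() below.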
1456 | |
1457 | } // namespace |
1458 | |
1459 | c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork( |
1460 | std::vector<at::Device> devices, |
1461 | int rank, |
1462 | OpType opType, |
1463 | const char* profilingTitle, |
1464 | const c10::optional<std::vector<at::Tensor>>& inputs) { |
1465 | return c10::make_intrusive<ProcessGroupNCCL::WorkNCCL>( |
1466 | devices, rank, opType, seq_, profilingTitle, inputs, desyncDebug_); |
1467 | } |
1468 | |
1469 | std::vector<at::Tensor> ProcessGroupNCCL::WorkNCCL::result() { |
1470 | return *outputs_; |
1471 | } |
1472 | |
1473 | c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupNCCL::WorkNCCL:: |
1474 | getFuture() { |
1475 | return future_; |
1476 | } |
1477 | |
1478 | void ProcessGroupNCCL::workEnqueue( |
1479 | c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> work) { |
1480 | if (!terminateProcessGroup_.load()) { |
1481 | std::lock_guard<std::mutex> lock(workMetaListMutex_); |
1482 | // Avoid having view tensors processed in the cleanup thread.
1483 | // A view tensor's destruction invokes autograd_meta, which
1484 | // needs to be destructed in the user thread; otherwise we
1485 | // can deadlock. Hence we enqueue the work without outputs_.
1486 | workMetaList_.emplace_back(*work); |
1487 | } |
1488 | } |
1489 | |
1490 | ProcessGroupNCCL::Options::Options(bool is_high_priority_stream) |
1491 | : Backend::Options(NCCL_BACKEND_NAME), |
1492 | is_high_priority_stream(is_high_priority_stream) {} |
1493 | |
1494 | void ProcessGroupNCCL::startCoalescing() { |
1495 | coalescedDevices_.clear(); |
1496 | coalescing_active_ = true; |
1497 | groupStart(); |
1498 | } |
1499 | |
1500 | void ProcessGroupNCCL::endCoalescing( |
1501 | std::vector<c10::intrusive_ptr<Work>>& reqs) { |
1502 | groupEnd(); |
1503 | if (reqs.size() != coalescedDevices_.size()) { |
1504 | TORCH_CHECK(false, "Number of requests does not match number of collectives");
1505 | } |
1506 | |
1507 | int batch_idx = 0; |
1508 | for (const auto& req : reqs) { |
1509 | auto ncclWork = static_cast<ProcessGroupNCCL::WorkNCCL*>(req.get()); |
1510 | // @lint-ignore CLANGTIDY |
1511 | std::vector<at::Device> devices = coalescedDevices_[batch_idx]; |
1512 | const auto key = getKeyFromDevices(devices); |
1513 | auto& ncclStreams = ncclStreams_[key]; |
1514 | for (const auto i : c10::irange(devices.size())) { |
1515 | (*ncclWork->ncclEndEvents_)[i].record(ncclStreams[i]); |
1516 | } |
1517 | batch_idx += 1; |
1518 | } |
1519 | coalescing_active_ = false; |
1520 | } |
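// Usage sketch of the coalescing API (this mirrors the uneven-size
// allgather/reduce_scatter paths below; illustration only):
//
//   std::vector<c10::intrusive_ptr<Work>> works;
//   pg->startCoalescing();
//   for (...) {
//     works.push_back(pg->_broadcast_oop(outputs, inputs, broadcastOpts));
//   }
//   pg->endCoalescing(works);  // records the end events for each request
//
// Everything issued between startCoalescing() and endCoalescing() is wrapped
// in a single ncclGroupStart()/ncclGroupEnd() pair.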
1521 | |
1522 | template <typename Fn, typename PreProcess, typename PostProcess> |
1523 | c10::intrusive_ptr<Work> ProcessGroupNCCL::collective( |
1524 | std::vector<at::Tensor>& inputs, |
1525 | std::vector<at::Tensor>& outputs, |
1526 | Fn fn, |
1527 | PreProcess pre, |
1528 | PostProcess post, |
1529 | OpType opType, |
1530 | const char* profilingTitle) { |
1531 | errorIfCapturingNonCapturableNCCL(); |
1532 | |
1533 | // Bump collective counter |
1534 | seq_++; |
1535 | |
1536 | // Currently, the API permits two scenarios where inputs.size() and |
1537 | // outputs.size() are > 0. |
1538 | // 1. If the call was a _coalesced call, all inputs must be on the same |
1539 | // device. |
1540 | // The group of nccl calls applies the collective separately to each input, |
1541 | // but the group as a whole should be efficient, and might even execute as |
1542 | // a single fused kernel. |
1543 | // 2. If the call was a _multigpu call, all inputs must be on different |
1544 | // devices. |
1545 | //    The nccl group applies the collective across them (e.g., if the collective
1546 | // is an allreduce, the output on each device contains contributions summed |
1547 | // across `inputs' tensors). |
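// A concrete reading of the two scenarios (for exposition only):
//   _coalesced: inputs = {t0, t1} both on cuda:0 -> devices = {cuda:0};
//               one communicator/stream is used and both operations are
//               issued inside one NCCL group.
//   _multigpu:  inputs = {t0 on cuda:0, t1 on cuda:1} -> devices =
//               {cuda:0, cuda:1}; one communicator/stream per device, and
//               the collective combines contributions across both tensors.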
1548 | const auto devices = getDeviceList(inputs); |
1549 | const bool inputs_same_dev = (devices.size() == 1); |
1550 | const auto key = getKeyFromDevices(devices); |
1551 | auto& ncclComms = getNCCLComm(key, devices, opType); |
1552 | |
1553 | if (coalescing_active_) { |
1554 | coalescedDevices_.push_back(devices); |
1555 | } |
1556 | |
1557 | // Used many times below, so we stash the unordered_map lookup |
1558 | auto& ncclStreams = ncclStreams_[key]; |
1559 | |
1560 | // First let NCCL streams wait for the input tensors' allocation streams
1561 | syncStreams(devices, ncclEvents_[key], ncclStreams); |
1562 | |
1563 | // Work itself will create the CUDA events on all GPUs of tensors |
1564 | bool can_profile = outputs.size() == 1; |
1565 | auto work = initWork( |
1566 | devices, |
1567 | rank_, |
1568 | opType, |
1569 | can_profile ? profilingTitle : nullptr, |
1570 | can_profile ? c10::optional<std::vector<at::Tensor>>(inputs) |
1571 | : c10::nullopt); |
1572 | |
1573 | // Store references to outputs to be used by WorkNCCL::result and operator<<. |
1574 | work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs); |
1575 | |
1576 | at::cuda::OptionalCUDAGuard gpuGuard; |
1577 | |
1578 | // Start event should only be recorded before the ncclGroupStart() |
1579 | if (desyncDebug_) { |
1580 | for (const auto i : c10::irange(devices.size())) { |
1581 | at::cuda::CUDAStream& ncclStream = ncclStreams[i]; |
1582 | (*work->ncclStartEvents_)[i].record(ncclStream); |
1583 | } |
1584 | } |
1585 | |
1586 | pre(ncclStreams); |
1587 | |
1588 | { |
1589 | torch::cuda::nccl::AutoNcclGroup nccl_group_guard; |
1590 | for (const auto i : c10::irange(inputs.size())) { |
1591 | if (!inputs_same_dev || (inputs_same_dev && i == 0)) { |
1592 | gpuGuard.set_index(devices[i].index()); |
1593 | } |
1594 | decltype(i) stream_comm_i = (inputs_same_dev ? 0 : i); |
1595 | auto& ncclStream = ncclStreams[stream_comm_i]; |
1596 | auto& ncclComm = ncclComms[stream_comm_i]; |
1597 | // Both `inputs' and `outputs' are created on a worker stream and used in |
1598 | // different ncclStreams. Hence, both must record the ncclStream to |
1599 | // prevent being freed before the collective finishes. |
1600 | // |
1601 | // We only record `inputs' here, and leave recording `outputs' to `fn' for |
1602 | // operations where `inputs' and `outputs' are not the same. |
1603 | // |
1604 | // See [Sync Streams]. |
1605 | c10::cuda::CUDACachingAllocator::recordStream( |
1606 | inputs[i].storage().data_ptr(), ncclStream); |
1607 | C10D_NCCL_CHECK( |
1608 | fn(inputs[i], outputs[i], ncclComm->getNcclComm(), ncclStream), |
1609 | ncclComm->getNcclCommFailureReason()); |
1610 | } |
1611 | } |
1612 | |
1613 | post(ncclStreams); |
1614 | |
1615 | // End event should only be recorded after the ncclGroupEnd() |
1616 | for (const auto i : c10::irange(devices.size())) { |
1617 | at::cuda::CUDAStream& ncclStream = ncclStreams[i]; |
1618 | if (!coalescing_active_) { |
1619 | (*work->ncclEndEvents_)[i].record(ncclStream); |
1620 | } |
1621 | work->ncclComms_[i] = ncclComms[i]; |
1622 | } |
1623 | |
1624 | { |
1625 | c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams); |
1626 | work->future_ = c10::make_intrusive<at::ivalue::Future>( |
1627 | c10::ListType::create(c10::TensorType::get()), devices); |
1628 | |
1629 | // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
1630 | // future blocks the stream this callback runs on behind the corresponding
1631 | // ncclEndEvents_, ensuring appropriate synchronization.
1632 | if (work->recordFunctionEndCallback_) { |
1633 | work->future_->addCallback([work](at::ivalue::Future& /* unused */) { |
1634 | work->recordFunctionEndCallback_(); |
1635 | }); |
1636 | } |
1637 | work->future_->markCompleted(at::IValue(*work->outputs_)); |
1638 | } |
1639 | |
1640 | // Set appropriate work parameters. |
1641 | work->blockingWait_ = blockingWait_; |
1642 | work->opTimeout_ = options_->timeout; |
1643 | work->store_ = store_; |
1644 | |
1645 | if (asyncErrorHandling_ != NoHandling) { |
1646 | workEnqueue(work); |
1647 | } |
1648 | |
1649 | return work; |
1650 | } |
1651 | |
1652 | template <typename Fn, typename PreProcess, typename PostProcess> |
1653 | c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint( |
1654 | std::vector<at::Tensor>& tensors, |
1655 | Fn fn, |
1656 | int peer, |
1657 | OpType opType, |
1658 | PreProcess pre, |
1659 | PostProcess post, |
1660 | const char* profilingTitle) { |
1661 | const auto devices = getDeviceList(tensors); |
1662 | std::string key; |
1663 | int p2pRank = 0, p2pTargetRank = 0; |
1664 | bool isSendRecvSelf = false; |
1665 | // For batch_isend_irecv, ncclGroupStart() would be called upfront |
1666 | bool batchP2P = ncclActiveGroupCounter_ > 0; |
1667 | if (batchP2P) { |
1668 | // For batch P2P, we need to treat it like a collective when selecting the
1669 | // communicator, because ranks other than this rank and its peer may also
1670 | // participate in the batch.
1671 | key = getKeyFromDevices(devices); |
1672 | p2pRank = rank_; |
1673 | p2pTargetRank = peer; |
1674 | } else { |
1675 | // For single P2P, preserve the old two-rank behavior (to avoid perf diff) |
1676 | key = getKeySendRecv(rank_, peer); |
1677 | p2pRank = rank_ <= peer ? 0 : 1; |
1678 | isSendRecvSelf = rank_ == peer; |
1679 | p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; |
1680 | } |
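// Worked example for the non-batch branch above (illustration only; the
// exact key string produced by getKeySendRecv is not spelled out here):
//   rank_ = 1, peer = 4 -> p2pRank = 0, isSendRecvSelf = false,
//                          p2pTargetRank = 1 (a dedicated two-rank comm).
//   rank_ = peer = 3    -> p2pRank = 0, isSendRecvSelf = true,
//                          p2pTargetRank = 0 (send/recv to self).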
1681 | auto& ncclComms = getNCCLComm(key, devices, opType, p2pRank, isSendRecvSelf); |
1682 | |
1683 | if (coalescing_active_) { |
1684 | coalescedDevices_.push_back(devices); |
1685 | } |
1686 | |
1687 | // First let NCCL streams wait for the input tensors' allocation streams
1688 | syncStreams(devices, ncclEvents_[key], ncclStreams_[key]); |
1689 | |
1690 | // Work itself will create the CUDA events on all GPUs of tensors |
1691 | bool can_profile = tensors.size() == 1; |
1692 | auto work = initWork( |
1693 | devices, |
1694 | rank_, |
1695 | opType, |
1696 | can_profile ? profilingTitle : nullptr, |
1697 | can_profile ? c10::optional<std::vector<at::Tensor>>(tensors) |
1698 | : c10::nullopt); |
1699 | |
1700 | // Store references to outputs to be used by WorkNCCL::result and operator<<. |
1701 | // Note that these outputs are only meaningful for recv(), since send() does
1702 | // not modify its inputs; we still create them for use cases such as
1703 | // profiling.
1704 | work->outputs_ = std::make_shared<std::vector<at::Tensor>>(tensors); |
1705 | |
1706 | at::cuda::OptionalCUDAGuard gpuGuard; |
1707 | |
1708 | // Start event should only be recorded before the ncclGroupStart() |
1709 | if (desyncDebug_) { |
1710 | for (const auto i : c10::irange(tensors.size())) { |
1711 | at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; |
1712 | (*work->ncclStartEvents_)[i].record(ncclStream); |
1713 | } |
1714 | } |
1715 | |
1716 | pre(ncclStreams_[key]); |
1717 | |
1718 | for (const auto i : c10::irange(tensors.size())) { |
1719 | gpuGuard.set_index(devices[i].index()); |
1720 | at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; |
1721 | |
1722 | // Both send tensor and recv tensor are created on a worker stream and used |
1723 | // in different ncclStreams. Hence, both must record the ncclStream to |
1724 | // prevent being freed before the collective finishes. |
1725 | // |
1726 | // See [Sync Streams]. |
1727 | c10::cuda::CUDACachingAllocator::recordStream( |
1728 | tensors[i].storage().data_ptr(), ncclStream); |
1729 | } |
1730 | |
1731 | { |
1732 | torch::cuda::nccl::AutoNcclGroup nccl_group_guard; |
1733 | for (const auto i : c10::irange(tensors.size())) { |
1734 | gpuGuard.set_index(devices[i].index()); |
1735 | at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; |
1736 | C10D_NCCL_CHECK( |
1737 | fn(tensors[i], |
1738 | ncclComms[i]->getNcclComm(), |
1739 | ncclStream, |
1740 | p2pTargetRank), |
1741 | ncclComms[i]->getNcclCommFailureReason()); |
1742 | } |
1743 | } |
1744 | |
1745 | post(ncclStreams_[key]); |
1746 | |
1747 | // End event should only be recorded after the ncclGroupEnd() |
1748 | for (const auto i : c10::irange(tensors.size())) { |
1749 | at::cuda::CUDAStream& ncclStream = ncclStreams_[key][i]; |
1750 | if (!coalescing_active_) { |
1751 | (*work->ncclEndEvents_)[i].record(ncclStream); |
1752 | } |
1753 | work->ncclComms_[i] = ncclComms[i]; |
1754 | work->blockingWait_ = blockingWait_; |
1755 | work->opTimeout_ = options_->timeout; |
1756 | work->store_ = store_; |
1757 | } |
1758 | |
1759 | // The future only needs to be created and marked completed with outputs for
1760 | // recv(); we still create one even for send(), for use cases such as
1761 | // profiling.
1762 | { |
1763 | c10::cuda::CUDAMultiStreamGuard streamGuard(ncclStreams_[key]); |
1764 | work->future_ = c10::make_intrusive<at::ivalue::Future>( |
1765 | c10::ListType::create(c10::TensorType::get()), devices); |
1766 | work->future_->markCompleted(at::IValue(*work->outputs_)); |
1767 | } |
1768 | |
1769 | // Add a callback that runs profiling end callbacks. wrapCallback() in CUDA
1770 | // future blocks the stream this callback runs on behind the corresponding
1771 | // ncclEndEvents_, ensuring appropriate synchronization.
1772 | if (work->recordFunctionEndCallback_) { |
1773 | work->future_->addCallback([work](at::ivalue::Future& /* unused */) { |
1774 | work->recordFunctionEndCallback_(); |
1775 | }); |
1776 | } |
1777 | |
1778 | return work; |
1779 | } |
1780 | |
1781 | template <typename Fn> |
1782 | c10::intrusive_ptr<Work> ProcessGroupNCCL::collective( |
1783 | std::vector<at::Tensor>& inputs, |
1784 | std::vector<at::Tensor>& outputs, |
1785 | Fn fn, |
1786 | OpType opType, |
1787 | const char* profilingTitle) { |
1788 | return collective( |
1789 | inputs, |
1790 | outputs, |
1791 | fn, |
1792 | [](std::vector<at::cuda::CUDAStream>&) {}, |
1793 | [](std::vector<at::cuda::CUDAStream>&) {}, |
1794 | opType, |
1795 | profilingTitle); |
1796 | } |
1797 | |
1798 | template <typename Fn> |
1799 | c10::intrusive_ptr<Work> ProcessGroupNCCL::pointToPoint( |
1800 | std::vector<at::Tensor>& tensor, |
1801 | Fn fn, |
1802 | int peer, |
1803 | OpType opType, |
1804 | const char* profilingTitle) { |
1805 | return pointToPoint( |
1806 | tensor, |
1807 | fn, |
1808 | peer, |
1809 | opType, |
1810 | [](std::vector<at::cuda::CUDAStream>&) {}, |
1811 | [](std::vector<at::cuda::CUDAStream>&) {}, |
1812 | profilingTitle); |
1813 | } |
1814 | |
1815 | c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_impl( |
1816 | std::vector<at::Tensor>& tensors, |
1817 | const AllreduceOptions& opts) { |
1818 | int dev_in_group = 0; |
1819 | return collective( |
1820 | tensors, |
1821 | tensors, |
1822 | [&](at::Tensor& input, |
1823 | at::Tensor& output, |
1824 | ncclComm_t comm, |
1825 | at::cuda::CUDAStream& stream) { |
1826 | auto ncclDataType = getNcclDataType(input.scalar_type()); |
1827 | auto ncclReduceOp = getNcclReduceOp( |
1828 | opts.reduceOp, input, ncclDataType, comm, dev_in_group++); |
1829 | return ncclAllReduce( |
1830 | input.data_ptr(), |
1831 | output.data_ptr(), |
1832 | input.numel(), |
1833 | ncclDataType, |
1834 | ncclReduceOp, |
1835 | comm, |
1836 | stream.stream()); |
1837 | }, |
1838 | OpType::ALLREDUCE, |
1839 | "nccl:all_reduce" ); |
1840 | } |
1841 | |
1842 | c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce( |
1843 | std::vector<at::Tensor>& tensors, |
1844 | const AllreduceOptions& opts) { |
1845 | check_gpu_tensors_different_devices(tensors); |
1846 | |
1847 | // @lint-ignore CLANGTIDY |
1848 | auto tensor = tensors.back(); |
1849 | RECORD_PARAM_COMMS_DATA( |
1850 | static_cast<int>( |
1851 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
1852 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
1853 | tensors, // inputTensors |
1854 | tensors, // outputTensors |
1855 | rank_, // rank |
1856 | "allreduce" , // colName |
1857 | tensor.numel(), // inSize |
1858 | tensor.numel(), // outSize |
1859 | tensor.scalar_type(), // dType |
1860 | std::vector<int64_t>(), // inSplitSizes |
1861 | std::vector<int64_t>()); // outSplitSizes |
1862 | |
1863 | return allreduce_impl(tensors, opts); |
1864 | } |
1865 | |
1866 | c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced( |
1867 | std::vector<at::Tensor>& tensors, |
1868 | const AllreduceCoalescedOptions& opts) { |
1869 | auto total_numel = check_gpu_tensors_same_device(tensors); |
1870 | |
1871 | // @lint-ignore CLANGTIDY |
1872 | RECORD_PARAM_COMMS_DATA( |
1873 | static_cast<int>( |
1874 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
1875 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
1876 | tensors, // inputTensors |
1877 | tensors, // outputTensors |
1878 | rank_, // rank |
1879 | "allreduce_coalesced" , // colName |
1880 | total_numel, // inSize |
1881 | total_numel, // outSize |
1882 | tensors[0].scalar_type(), // dType |
1883 | // It is unclear what in/outSplitSizes would mean here, so they are left empty.
1884 | std::vector<int64_t>(), // inSplitSizes |
1885 | std::vector<int64_t>()); // outSplitSizes |
1886 | |
1887 | return allreduce_impl(tensors, opts); |
1888 | } |
1889 | |
1890 | c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast( |
1891 | std::vector<at::Tensor>& tensors, |
1892 | const BroadcastOptions& opts) { |
1893 | check_gpu_tensors_different_devices(tensors); |
1894 | |
1895 | // @lint-ignore CLANGTIDY |
1896 | auto tensor = tensors.back(); |
1897 | |
1898 | RECORD_PARAM_COMMS_DATA( |
1899 | static_cast<int>( |
1900 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
1901 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
1902 | tensors, // inputTensors |
1903 | tensors, // outputTensors |
1904 | rank_, // rank |
1905 | "broadcast" , // colName |
1906 | tensor.numel(), // inSize |
1907 | tensor.numel(), // outSize |
1908 | tensor.scalar_type(), // dType |
1909 | std::vector<int64_t>(), // inSplitSizes |
1910 | std::vector<int64_t>()); // outSplitSizes |
1911 | |
1912 | return collective( |
1913 | tensors, |
1914 | tensors, |
1915 | [&](at::Tensor& input, |
1916 | at::Tensor& output, |
1917 | ncclComm_t comm, |
1918 | at::cuda::CUDAStream& stream) { |
1919 | const auto root = opts.rootRank * tensors.size() + opts.rootTensor; |
1920 | return ncclBcast( |
1921 | input.data_ptr(), |
1922 | input.numel(), |
1923 | getNcclDataType(input.scalar_type()), |
1924 | root, |
1925 | comm, |
1926 | stream.stream()); |
1927 | }, |
1928 | OpType::BROADCAST, |
1929 | "nccl:broadcast" ); |
1930 | } |
1931 | |
1932 | // _broadcast_oop adds an out-of-place broadcast to PGNCCL.
1933 | // Custom collectives may be implemented by coalescing broadcast operations.
1934 | // One use case is implementing a vector all_gather (all_gather_v),
1935 | // where unevenly sized inputs are gathered among participating ranks.
1936 | // Since all_gather provides an out-of-place API, an all_gather_v
1937 | // semantic implemented inside pg_nccl.all_gather also needs to support
1938 | // out-of-place, for which an out-of-place broadcast must be added.
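// Sketch of how all_gather_v is assembled from this primitive (this mirrors
// the uneven-size branch of allgather() below; illustration only): within a
// startCoalescing()/endCoalescing() block, one out-of-place broadcast is
// issued per gathered slot with root set to the rank that owns the slot;
// the owner broadcasts its (arbitrarily sized) input while every other rank
// receives into the matching output tensor. Coalescing submits all of these
// broadcasts as a single NCCL group.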
1939 | c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop( |
1940 | std::vector<at::Tensor>& outputTensors, |
1941 | std::vector<at::Tensor>& inputTensors, |
1942 | const BroadcastOptions& opts) { |
1943 | check_gpu_tensors_different_devices(outputTensors); |
1944 | check_gpu_tensors_different_devices(inputTensors); |
1945 | |
1946 | // @lint-ignore CLANGTIDY |
1947 | auto tensor = outputTensors.back(); |
1948 | // @lint-ignore CLANGTIDY |
1949 | auto in_tensor = inputTensors.back(); |
1950 | if (tensor.numel() != in_tensor.numel()) { |
1951 | TORCH_CHECK( |
1952 | false, |
1953 | "Tensor input and output of _broadcast_oop must have the same number of elements " ); |
1954 | } |
1955 | RECORD_PARAM_COMMS_DATA( |
1956 | static_cast<int>( |
1957 | this->getSequenceNumberForGroup() + |
1958 | 1), // seq + 1 to match collective increment. |
1959 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
1960 | inputTensors, // inputTensors |
1961 | outputTensors, // outputTensors |
1962 | rank_, // rank |
1963 | "_broadcast_oop" , // colName |
1964 | tensor.numel(), // inSize |
1965 | tensor.numel(), // outSize |
1966 | tensor.scalar_type(), // dType |
1967 | std::vector<int64_t>(), // inSplitSizes |
1968 | std::vector<int64_t>()); // outSplitSizes |
1969 | |
1970 | return collective( |
1971 | inputTensors, |
1972 | outputTensors, |
1973 | [&](at::Tensor& input, |
1974 | at::Tensor& output, |
1975 | ncclComm_t comm, |
1976 | at::cuda::CUDAStream& stream) { |
1977 | const auto root = opts.rootRank * inputTensors.size() + opts.rootTensor; |
1978 | return ncclBroadcast( |
1979 | input.data_ptr(), |
1980 | output.data_ptr(), |
1981 | input.numel(), |
1982 | getNcclDataType(input.scalar_type()), |
1983 | root, |
1984 | comm, |
1985 | stream.stream()); |
1986 | }, |
1987 | OpType::BROADCAST, |
1988 | "nccl:_broadcast_oop" ); |
1989 | } |
1990 | |
1991 | c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce( |
1992 | std::vector<at::Tensor>& tensors, |
1993 | const ReduceOptions& opts) { |
1994 | check_gpu_tensors_different_devices(tensors); |
1995 | // @lint-ignore CLANGTIDY |
1996 | auto tensor = tensors.back(); |
1997 | RECORD_PARAM_COMMS_DATA( |
1998 | static_cast<int>( |
1999 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
2000 | reinterpret_cast<std::intptr_t>(this), |
2001 | tensors, // inputTensors |
2002 | tensors, // outputTensors |
2003 | rank_, // rank |
2004 | "reduce" , // colName |
2005 | tensor.numel(), // inSize |
2006 | tensor.numel(), // outSize |
2007 | tensor.scalar_type(), // dType |
2008 | std::vector<int64_t>(), // inSplitSizes |
2009 | std::vector<int64_t>()); // outSplitSizes |
2010 | |
2011 | int dev_in_group = 0; |
2012 | return collective( |
2013 | tensors, |
2014 | tensors, |
2015 | [&](at::Tensor& input, |
2016 | at::Tensor& output, |
2017 | ncclComm_t comm, |
2018 | at::cuda::CUDAStream& stream) { |
2019 | const auto root = opts.rootRank * tensors.size() + opts.rootTensor; |
2020 | auto ncclDataType = getNcclDataType(input.scalar_type()); |
2021 | auto ncclReduceOp = getNcclReduceOp( |
2022 | opts.reduceOp, input, ncclDataType, comm, dev_in_group++); |
2023 | return ncclReduce( |
2024 | input.data_ptr(), |
2025 | output.data_ptr(), |
2026 | input.numel(), |
2027 | ncclDataType, |
2028 | ncclReduceOp, |
2029 | root, |
2030 | comm, |
2031 | stream.stream()); |
2032 | }, |
2033 | OpType::REDUCE, |
2034 | "nccl:reduce" ); |
2035 | } |
2036 | |
2037 | // _reduce_oop exposes an out-of-place reduce from PGNCCL.
2038 | // Custom collectives may be implemented by coalescing reduce operations.
2039 | // One use case is implementing a vector reduce_scatter (reduce_scatter_v),
2040 | // where inputs are reduced and scattered unevenly among participating ranks.
2041 | // Since reduce_scatter provides an out-of-place API, a reduce_scatter_v
2042 | // semantic implemented inside pg_nccl.reduce_scatter also needs to support
2043 | // out-of-place, for which an out-of-place reduce must be added.
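// Sketch of how reduce_scatter_v is assembled from this primitive (this
// mirrors the uneven-size branch of reduce_scatter() below; illustration
// only): within a startCoalescing()/endCoalescing() block, one out-of-place
// reduce is issued per scattered slot with root set to the rank that owns
// the slot; the owner reduces into its output tensor while the other ranks
// contribute their inputs. Coalescing submits all of these reduces as a
// single NCCL group.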
2044 | c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop( |
2045 | std::vector<at::Tensor>& outputTensors, |
2046 | std::vector<at::Tensor>& inputTensors, |
2047 | const ReduceOptions& opts) { |
2048 | check_gpu_tensors_different_devices(outputTensors); |
2049 | check_gpu_tensors_different_devices(inputTensors); |
2050 | // @lint-ignore CLANGTIDY |
2051 | auto tensor = outputTensors.back(); |
2052 | // @lint-ignore CLANGTIDY |
2053 | auto in_tensor = inputTensors.back(); |
2054 | if (tensor.numel() != in_tensor.numel()) { |
2055 | TORCH_CHECK( |
2056 | false, |
2057 | "Tensor input and output of _reduce_oop must have the same number of elements " ); |
2058 | } |
2059 | RECORD_PARAM_COMMS_DATA( |
2060 | static_cast<int>( |
2061 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
2062 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2063 | inputTensors, // inputTensors |
2064 | outputTensors, // outputTensors |
2065 | rank_, // rank |
2066 | "_reduce_oop" , // colName |
2067 | tensor.numel(), // inSize |
2068 | tensor.numel(), // outSize |
2069 | tensor.scalar_type(), // dType |
2070 | std::vector<int64_t>(), // inSplitSizes |
2071 | std::vector<int64_t>()); // outSplitSizes |
2072 | |
2073 | int dev_in_group{0}; |
2074 | return collective( |
2075 | inputTensors, |
2076 | outputTensors, |
2077 | [&](at::Tensor& input, |
2078 | at::Tensor& output, |
2079 | ncclComm_t comm, |
2080 | at::cuda::CUDAStream& stream) { |
2081 | const auto root = opts.rootRank * inputTensors.size() + opts.rootTensor; |
2082 | const auto ncclDataType = getNcclDataType(input.scalar_type()); |
2083 | const auto ncclReduceOp = getNcclReduceOp( |
2084 | opts.reduceOp, input, ncclDataType, comm, dev_in_group++); |
2085 | return ncclReduce( |
2086 | input.data_ptr(), |
2087 | output.data_ptr(), |
2088 | input.numel(), |
2089 | ncclDataType, |
2090 | ncclReduceOp, |
2091 | (int)root, |
2092 | comm, |
2093 | stream.stream()); |
2094 | }, |
2095 | OpType::REDUCE, |
2096 | "nccl:_reduce_oop" ); |
2097 | } |
2098 | |
2099 | c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather( |
2100 | std::vector<std::vector<at::Tensor>>& outputTensors, |
2101 | std::vector<at::Tensor>& inputTensors, |
2102 | const AllgatherOptions& opts) { |
2103 | check_gpu_tensors_different_devices(inputTensors); |
2104 | // @lint-ignore CLANGTIDY |
2105 | bool same_size = check_same_size(outputTensors.back()); |
2106 | |
2107 | if (same_size) { |
2108 | auto outputFlattened = |
2109 | flatten_for_scatter_gather(outputTensors, inputTensors, size_); |
2110 | check_gpu_tensors_different_devices(outputFlattened); |
2111 | |
2112 | // @lint-ignore CLANGTIDY |
2113 | auto tensor = inputTensors.back(); |
2114 | RECORD_PARAM_COMMS_DATA( |
2115 | static_cast<int>( |
2116 | this->getSequenceNumberForGroup() + |
2117 | 1), // seq + 1 to match collective |
2118 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2119 | inputTensors, // inputTensors |
2120 | outputTensors, // outputTensors |
2121 | rank_, // rank |
2122 | "all_gather" , // colName |
2123 | tensor.numel(), // inSize |
2124 | tensor.numel() * // outSize
2125 |     this->getSize(),
2126 | tensor.scalar_type(), // dType
2127 | std::vector<int64_t>(), // inSplitSizes
2128 | std::vector<int64_t>()); // outSplitSizes
2129 | |
2130 | return collective( |
2131 | inputTensors, |
2132 | outputFlattened, |
2133 | [&](at::Tensor& input, |
2134 | at::Tensor& output, |
2135 | ncclComm_t comm, |
2136 | at::cuda::CUDAStream& stream) { |
2137 | c10::cuda::CUDACachingAllocator::recordStream( |
2138 | output.storage().data_ptr(), stream); |
2139 | return ncclAllGather( |
2140 | input.data_ptr(), |
2141 | output.data_ptr(), |
2142 | input.numel(), |
2143 | getNcclDataType(input.scalar_type()), |
2144 | comm, |
2145 | stream.stream()); |
2146 | }, |
2147 | [&](std::vector<at::cuda::CUDAStream>& ncclStreams) {}, |
2148 | [&](std::vector<at::cuda::CUDAStream>& ncclStreams) { |
2149 | // Copy the flattened output tensors to the outputs. |
2150 | for (const auto i : c10::irange(outputTensors.size())) { |
2151 | at::cuda::CUDAStreamGuard guard(ncclStreams[i]); |
2152 | for (const auto j : c10::irange(outputTensors[0].size())) { |
2153 | // See [Sync Streams]. |
2154 | c10::cuda::CUDACachingAllocator::recordStream( |
2155 | outputTensors[i][j].storage().data_ptr(), ncclStreams[i]); |
2156 | |
2157 | outputTensors[i][j].copy_(outputFlattened[i][j], true); |
2158 | } |
2159 | } |
2160 | }, |
2161 | OpType::ALLGATHER, |
2162 | "nccl:all_gather" ); |
2163 | } else { |
2164 | const auto num_devices = outputTensors.size(); |
2165 | const auto num_reduces = outputTensors[0].size(); |
2166 | std::vector<c10::intrusive_ptr<Work>> works; |
2167 | startCoalescing(); |
2168 | for (const auto i : c10::irange(num_reduces)) { |
2169 | std::vector<at::Tensor> inputs_multi_dev(num_devices); |
2170 | std::vector<at::Tensor> outputs_multi_dev(num_devices); |
2171 | for (const auto j : c10::irange(num_devices)) { |
2172 | // @lint-ignore CLANGTIDY |
2173 | outputs_multi_dev[j] = outputTensors[j][i]; |
2174 | inputs_multi_dev[j] = |
2175 | // @lint-ignore CLANGTIDY |
2176 | i == (rank_ * num_devices + j) ? inputTensors[j] |
2177 | : outputs_multi_dev[j]; |
2178 | } |
2179 | auto broadcastOpts = BroadcastOptions{ |
2180 | static_cast<int64_t>(i / num_devices), |
2181 | static_cast<int64_t>(i % num_devices), |
2182 | opts.timeout}; |
2183 | auto work = |
2184 | _broadcast_oop(outputs_multi_dev, inputs_multi_dev, broadcastOpts); |
2185 | works.push_back(work); |
2186 | } |
2187 | endCoalescing(works); |
2188 | return initCoalescedWork(works, rank_, OpType::BROADCAST); |
2189 | } |
2190 | } |
2191 | |
2192 | c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_coalesced( |
2193 | std::vector<std::vector<at::Tensor>>& /* unused */, |
2194 | std::vector<at::Tensor>& /* unused */, |
2195 | const AllgatherOptions& /* unused */) { |
2196 | TORCH_CHECK(false, "ProcessGroupNCCL does not support allgather_coalesced" ); |
2197 | } |
2198 | |
2199 | c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter( |
2200 | std::vector<at::Tensor>& outputTensors, |
2201 | std::vector<std::vector<at::Tensor>>& inputTensors, |
2202 | const ReduceScatterOptions& opts) { |
2203 | check_gpu_tensors_different_devices(outputTensors); |
2204 | // @lint-ignore CLANGTIDY |
2205 | bool same_size = check_same_size(inputTensors.back()); |
2206 | |
2207 | if (same_size) { |
2208 | // @lint-ignore CLANGTIDY |
2209 | auto tensor = outputTensors.back(); |
2210 | |
2211 | int dev_in_group{0}; |
2212 | auto inputFlattened = |
2213 | flatten_for_scatter_gather(inputTensors, outputTensors, size_); |
2214 | check_gpu_tensors_different_devices(inputFlattened); |
2215 | |
2216 | RECORD_PARAM_COMMS_DATA( |
2217 | static_cast<int>( |
2218 | this->getSequenceNumberForGroup() + |
2219 | 1), // seq + 1 to match collective |
2220 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2221 | inputTensors, // inputTensors |
2222 | outputTensors, // outputTensors |
2223 | rank_, // rank |
2224 | "reduce_scatter" , // colName |
2225 | tensor.numel() * this->getSize(), // inSize |
2226 | tensor.numel(), // outSize |
2227 | tensor.scalar_type(), // dType |
2228 | std::vector<int64_t>(), // inSplitSizes |
2229 | std::vector<int64_t>()); // outSplitSizes |
2230 | |
2231 | return collective( |
2232 | inputFlattened, |
2233 | outputTensors, |
2234 | [&](at::Tensor& input, |
2235 | at::Tensor& output, |
2236 | ncclComm_t comm, |
2237 | at::cuda::CUDAStream& stream) { |
2238 | c10::cuda::CUDACachingAllocator::recordStream( |
2239 | output.storage().data_ptr(), stream); |
2240 | const auto ncclDataType = getNcclDataType(input.scalar_type()); |
2241 | const auto ncclReduceOp = getNcclReduceOp( |
2242 | opts.reduceOp, input, ncclDataType, comm, dev_in_group++); |
2243 | return ncclReduceScatter( |
2244 | input.data_ptr(), |
2245 | output.data_ptr(), |
2246 | output.numel(), |
2247 | ncclDataType, |
2248 | ncclReduceOp, |
2249 | comm, |
2250 | stream.stream()); |
2251 | }, |
2252 | [&](std::vector<at::cuda::CUDAStream>& ncclStreams) { |
2253 | // Copy the input tensors to the flattened inputs. |
2254 | for (const auto i : c10::irange(inputTensors.size())) { |
2255 | at::cuda::CUDAStreamGuard guard(ncclStreams[i]); |
2256 | for (const auto j : c10::irange(inputTensors[0].size())) { |
2257 | // See [Sync Streams]. |
2258 | c10::cuda::CUDACachingAllocator::recordStream( |
2259 | inputTensors[i][j].storage().data_ptr(), ncclStreams[i]); |
2260 | |
2261 | inputFlattened[i][j].copy_(inputTensors[i][j], true); |
2262 | } |
2263 | } |
2264 | }, |
2265 | [&](std::vector<at::cuda::CUDAStream>&) {}, |
2266 | OpType::REDUCE_SCATTER, |
2267 | "nccl:reduce_scatter" ); |
2268 | } else { |
2269 | const auto num_devices = inputTensors.size(); |
2270 | const auto num_reduces = inputTensors[0].size(); |
2271 | std::vector<c10::intrusive_ptr<Work>> works; |
2272 | startCoalescing(); |
2273 | for (const auto i : c10::irange(num_reduces)) { |
2274 | std::vector<at::Tensor> inputs_multi_dev(num_devices); |
2275 | std::vector<at::Tensor> outputs_multi_dev(num_devices); |
2276 | for (const auto j : c10::irange(num_devices)) { |
2277 | // @lint-ignore CLANGTIDY |
2278 | inputs_multi_dev[j] = inputTensors[j][i]; |
2279 | outputs_multi_dev[j] = |
2280 | // @lint-ignore CLANGTIDY |
2281 | i == (rank_ * num_devices + j) ? outputTensors[j] |
2282 | : inputs_multi_dev[j]; |
2283 | } |
2284 | auto reduceOpts = ReduceOptions{ |
2285 | opts.reduceOp, |
2286 | static_cast<int64_t>(i / num_devices), |
2287 | static_cast<int64_t>(i % num_devices), |
2288 | opts.timeout}; |
2289 | auto work = _reduce_oop(outputs_multi_dev, inputs_multi_dev, reduceOpts); |
2290 | works.push_back(work); |
2291 | } |
2292 | endCoalescing(works); |
2293 | return initCoalescedWork(works, rank_, OpType::REDUCE); |
2294 | } |
2295 | } |
2296 | |
2297 | c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base( |
2298 | at::Tensor& outputTensor, |
2299 | at::Tensor& inputTensor, |
2300 | const ReduceScatterOptions& opts) { |
2301 | if (inputTensor.dtype() != outputTensor.dtype()) { |
2302 | TORCH_CHECK( |
2303 | false, "input tensor must be the same type as the output tensor." ); |
2304 | } |
2305 | |
2306 | if (inputTensor.numel() != outputTensor.numel() * size_) { |
2307 | TORCH_CHECK( |
2308 | false, |
2309 | "input tensor must be the same size as output size times world size" ); |
2310 | } |
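// Size relationship, illustrated with example values only: with size_ = 4
// and an output of 1024 elements on each rank, the input must hold
// 4 * 1024 = 4096 elements and is treated as the concatenation of the four
// per-rank chunks in rank order.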
2311 | |
2312 | // @lint-ignore CLANGTIDY |
2313 | const auto& tensor = outputTensor; |
2314 | |
2315 | RECORD_PARAM_COMMS_DATA( |
2316 | static_cast<int>( |
2317 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
2318 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2319 | inputTensor, // inputTensor |
2320 | outputTensor, // outputTensor |
2321 | rank_, // rank |
2322 | "_reduce_scatter_base" , // colName |
2323 | tensor.numel() * // inSize |
2324 | this->getSize(), |
2325 | tensor.numel(), // outSize |
2326 | tensor.scalar_type(), // dtype |
2327 | std::vector<int64_t>(), // inSplitSizes |
2328 | std::vector<int64_t>()); // outSplitSizes |
2329 | |
2330 | auto inputs = std::vector<at::Tensor>{inputTensor}; |
2331 | auto outputs = std::vector<at::Tensor>{outputTensor}; |
2332 | |
2333 | int dev_in_group = 0; |
2334 | return collective( |
2335 | inputs, |
2336 | outputs, |
2337 | [&](at::Tensor& input, |
2338 | at::Tensor& output, |
2339 | ncclComm_t comm, |
2340 | at::cuda::CUDAStream& stream) { |
2341 | c10::cuda::CUDACachingAllocator::recordStream( |
2342 | output.storage().data_ptr(), stream); |
2343 | auto ncclDataType = getNcclDataType(input.scalar_type()); |
2344 | auto ncclReduceOp = getNcclReduceOp( |
2345 | opts.reduceOp, input, ncclDataType, comm, dev_in_group++); |
2346 | return ncclReduceScatter( |
2347 | input.data_ptr(), |
2348 | output.data_ptr(), |
2349 | output.numel(), |
2350 | ncclDataType, |
2351 | ncclReduceOp, |
2352 | comm, |
2353 | stream.stream()); |
2354 | }, |
2355 | [&](std::vector<at::cuda::CUDAStream>&) {}, |
2356 | [&](std::vector<at::cuda::CUDAStream>&) {}, |
2357 | OpType::_REDUCE_SCATTER_BASE, |
2358 | "nccl:_reduce_scatter_base" ); |
2359 | } |
2360 | |
2361 | c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) { |
2362 | RECORD_PARAM_COMMS( |
2363 | static_cast<int>( |
2364 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
2365 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2366 | rank_, // rank |
2367 | "barrier" , // colName |
2368 | 0, // inSize |
2369 | 0, // outSize |
2370 | at::kByte, // dType |
2371 | std::vector<int64_t>(), // inSplitSizes |
2372 | std::vector<int64_t>()); // outSplitSizes |
2373 | |
2374 | std::vector<at::Device> devices; |
2375 | |
2376 | // Use user defined GPU device ids if provided |
2377 | if (!opts.device_ids.empty()) { |
2378 | for (auto device : opts.device_ids) { |
2379 | devices.emplace_back(at::DeviceType::CUDA, device); |
2380 | } |
2381 | } else if (usedDeviceIdxs_.empty()) { |
2382 | // This means no NCCL collective has been called on this group yet. We have
2383 | // to make a best guess and use a single GPU to call allreduce to achieve
2384 | // the barrier.
2385 | // In case multiple processes land on the same node, we use the rank to
2386 | // ensure that each process ends up on a different GPU.
2387 | auto numGPUs = at::cuda::getNumGPUs(); |
2388 | int16_t deviceIdx = static_cast<int16_t>(rank_ % numGPUs); |
2389 | LOG(INFO) << c10::str( |
2390 | "Rank " , |
2391 | this->getRank(), |
2392 | " using GPU " , |
2393 | deviceIdx, |
2394 | " to perform barrier as devices used by this process are currently unknown. " , |
2395 | "This can potentially cause a hang if this rank-to-GPU mapping is incorrect. ",
2396 | "Specify device_ids in barrier() to force use of a particular device.");
2397 | devices.emplace_back(getDeviceForRank(rank_)); |
2398 | } else { |
2399 | for (auto usedDeviceIdx : usedDeviceIdxs_) { |
2400 | devices.emplace_back(at::DeviceType::CUDA, usedDeviceIdx); |
2401 | } |
2402 | } |
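// Illustration of the fallback guess logged above (example values only):
// with numGPUs = 8, rank 3 reports GPU 3 and rank 11 (on a second 8-GPU
// node) also reports GPU 3, since the reported index is rank_ % numGPUs.
// The device actually used comes from getDeviceForRank(rank_), which is
// assumed here to follow the same rank-to-GPU convention.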
2403 | |
2404 | std::vector<at::Tensor> barrierTensors; |
2405 | barrierTensors.reserve(devices.size()); |
2406 | |
2407 | at::cuda::OptionalCUDAGuard gpuGuard; |
2408 | for (auto& device : devices) { |
2409 | gpuGuard.set_index(device.index()); |
2410 | barrierTensors.push_back(at::empty( |
2411 | {1}, |
2412 | at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte))); |
2413 | } |
2414 | |
2415 | // All reduce to achieve the barrier |
2416 | auto work = allreduce(barrierTensors); |
2417 | |
2418 | // Work will take over barrierTensors |
2419 | auto ncclWork = dynamic_cast<ProcessGroupNCCL::WorkNCCL*>(work.get()); |
2420 | TORCH_CHECK(ncclWork); |
2421 | ncclWork->barrierTensors_ = std::move(barrierTensors); |
2422 | |
2423 | return work; |
2424 | } |
2425 | |
2426 | #ifdef ENABLE_NCCL_P2P_SUPPORT |
2427 | c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base( |
2428 | at::Tensor& outputTensor, |
2429 | at::Tensor& inputTensor, |
2430 | std::vector<int64_t>& outputSplitSizes, |
2431 | std::vector<int64_t>& inputSplitSizes, |
2432 | const AllToAllOptions& /* unused */) { |
2433 | check_gpu_single_tensor(outputTensor); |
2434 | check_gpu_single_tensor(inputTensor); |
2435 | if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { |
2436 | std::vector<at::Tensor> inputTensors = {inputTensor}; |
2437 | std::vector<at::Tensor> outputTensors = {outputTensor}; |
2438 | |
2439 | RECORD_PARAM_COMMS_DATA( |
2440 | static_cast<int>( |
2441 | this->getSequenceNumberForGroup() + |
2442 | 1), // seq + 1 to match collective |
2443 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2444 | inputTensor, // inputTensor |
2445 | outputTensor, // outputTensor |
2446 | rank_, // rank |
2447 | "all_to_all" , // colName |
2448 | inputTensor.numel(), // inSize |
2449 | outputTensor.numel(), // outSize |
2450 | inputTensor.scalar_type(), // dType |
2451 | std::vector<int64_t>(), // inSplitSizes |
2452 | std::vector<int64_t>()); // outSplitSizes |
2453 | |
2454 | return collective( |
2455 | inputTensors, |
2456 | outputTensors, |
2457 | [&](at::Tensor& input, |
2458 | at::Tensor& output, |
2459 | ncclComm_t comm, |
2460 | at::cuda::CUDAStream& stream) { |
2461 | // See [Sync Streams]. |
2462 | c10::cuda::CUDACachingAllocator::recordStream( |
2463 | output.storage().data_ptr(), stream); |
2464 | torch::cuda::nccl::all2all_single_equal_split( |
2465 | input, output, this->getSize(), comm, stream); |
2466 | return ncclSuccess; |
2467 | }, |
2468 | OpType::ALLTOALL_BASE, |
2469 | "nccl:all_to_all" ); |
2470 | } else { |
2471 | c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); |
2472 | c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); |
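// Example of the uneven-split path (example values only; the unit handling
// inside computeLengthsAndOffsets is not restated here): with size_ = 4 and
// an input of shape [8, 5], inputSplitSizes = {2, 3, 1, 2} sends rows 0-1
// to rank 0, rows 2-4 to rank 1, row 5 to rank 2, and rows 6-7 to rank 3;
// outputSplitSizes describes, analogously, how many rows are received from
// each rank.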
2473 | std::vector<at::Tensor> inputTensors = {inputTensor}; |
2474 | std::vector<at::Tensor> outputTensors = {outputTensor}; |
2475 | |
2476 | RECORD_PARAM_COMMS_DATA( |
2477 | static_cast<int>( |
2478 | this->getSequenceNumberForGroup() + |
2479 | 1), // seq + 1 to match collective |
2480 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2481 | inputTensor, // inputTensor |
2482 | outputTensor, // outputTensor |
2483 | rank_, // rank |
2484 | "all_to_allv" , // colName |
2485 | inputTensor.numel(), // inSize |
2486 | outputTensor.numel(), // outSize |
2487 | inputTensor.scalar_type(), // dType |
2488 | inputSplitSizes, // inSplitSizes |
2489 | outputSplitSizes); // outSplitSizes |
2490 | |
2491 | return collective( |
2492 | inputTensors, |
2493 | outputTensors, |
2494 | [&](at::Tensor& input, |
2495 | at::Tensor& output, |
2496 | ncclComm_t comm, |
2497 | at::cuda::CUDAStream& stream) { |
2498 | std::vector<size_t> send_lengths(size_); |
2499 | std::vector<size_t> recv_lengths(size_); |
2500 | std::vector<size_t> send_offsets(size_); |
2501 | std::vector<size_t> recv_offsets(size_); |
2502 | c10d::computeLengthsAndOffsets( |
2503 | inputSplitSizes, input, &send_lengths, &send_offsets); |
2504 | c10d::computeLengthsAndOffsets( |
2505 | outputSplitSizes, output, &recv_lengths, &recv_offsets); |
2506 | // See [Sync Streams]. |
2507 | c10::cuda::CUDACachingAllocator::recordStream( |
2508 | output.storage().data_ptr(), stream); |
2509 | torch::cuda::nccl::all2all_single_unequal_split( |
2510 | input.data_ptr(), |
2511 | send_lengths.data(), |
2512 | send_offsets.data(), |
2513 | output.data_ptr(), |
2514 | recv_lengths.data(), |
2515 | recv_offsets.data(), |
2516 | input.element_size(), |
2517 | input.scalar_type(), |
2518 | comm, |
2519 | stream); |
2520 | return ncclSuccess; |
2521 | }, |
2522 | OpType::ALLTOALL_BASE, |
2523 | "nccl:all_to_all" ); |
2524 | } |
2525 | } |
2526 | |
2527 | c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall( |
2528 | std::vector<at::Tensor>& outputTensors, |
2529 | std::vector<at::Tensor>& inputTensors, |
2530 | const AllToAllOptions& /* unused */) { |
2531 | auto device = outputTensors[0].device(); |
2532 | for (const auto r : c10::irange(outputTensors.size())) { |
2533 | check_gpu_single_tensor(outputTensors[r]); |
2534 | check_gpu_single_tensor(inputTensors[r]); |
2535 | TORCH_CHECK( |
2536 | device == outputTensors[r].device() && |
2537 | device == inputTensors[r].device(), |
2538 | "Tensors must be on the same device" ) |
2539 | } |
2540 | std::vector<at::Tensor> inputTensor0 = {inputTensors[0]}; |
2541 | std::vector<at::Tensor> outputTensor0 = {outputTensors[0]}; |
2542 | return collective( |
2543 | inputTensor0, |
2544 | outputTensor0, |
2545 | [&](at::Tensor& /* unused */, |
2546 | at::Tensor& /* unused */, |
2547 | ncclComm_t comm, |
2548 | at::cuda::CUDAStream& stream) { |
2549 | torch::cuda::nccl::all2all(outputTensors, inputTensors, comm, stream); |
2550 | return ncclSuccess; |
2551 | }, |
2552 | OpType::ALLTOALL); |
2553 | } |
2554 | |
2555 | c10::intrusive_ptr<Work> ProcessGroupNCCL::send( |
2556 | std::vector<at::Tensor>& tensors, |
2557 | int dstRank, |
2558 | int /* unused */) { |
2559 | check_gpu_tensors_different_devices(tensors); |
2560 | auto ret = pointToPoint( |
2561 | tensors, |
2562 | [&](at::Tensor& input, |
2563 | ncclComm_t comm, |
2564 | at::cuda::CUDAStream& stream, |
2565 | int dst) { |
2566 | torch::cuda::nccl::send(input, comm, stream, dst); |
2567 | return ncclSuccess; |
2568 | }, |
2569 | dstRank, |
2570 | OpType::SEND, |
2571 | "nccl:send" ); |
2572 | return ret; |
2573 | } |
2574 | |
2575 | c10::intrusive_ptr<Work> ProcessGroupNCCL::recv( |
2576 | std::vector<at::Tensor>& tensors, |
2577 | int srcRank, |
2578 | int /* unused */) { |
2579 | check_gpu_tensors_different_devices(tensors); |
2580 | auto ret = pointToPoint( |
2581 | tensors, |
2582 | [&](at::Tensor& output, |
2583 | ncclComm_t comm, |
2584 | at::cuda::CUDAStream& stream, |
2585 | int src) { |
2586 | torch::cuda::nccl::recv(output, comm, stream, src); |
2587 | return ncclSuccess; |
2588 | }, |
2589 | srcRank, |
2590 | OpType::RECV, |
2591 | "nccl:recv" ); |
2592 | return ret; |
2593 | } |
2594 | #else |
2595 | c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base( |
2596 | at::Tensor& /* unused */, |
2597 | at::Tensor& /* unused */, |
2598 | std::vector<int64_t>& /* unused */, |
2599 | std::vector<int64_t>& /* unused */, |
2600 | const AllToAllOptions& /* unused */) { |
2601 | TORCH_CHECK( |
2602 | false, |
2603 | "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0" ); |
2604 | } |
2605 | |
2606 | c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall( |
2607 | std::vector<at::Tensor>& /* unused */, |
2608 | std::vector<at::Tensor>& /* unused */, |
2609 | const AllToAllOptions& /* unused */) { |
2610 | TORCH_CHECK( |
2611 | false, |
2612 | "ProcessGroupNCCL only supports alltoall* for NCCL lib version >= 2.7.0" ); |
2613 | } |
2614 | |
2615 | c10::intrusive_ptr<Work> ProcessGroupNCCL::send( |
2616 | std::vector<at::Tensor>& /* unused */, |
2617 | int /* unused */, |
2618 | int /* unused */) { |
2619 | TORCH_CHECK( |
2620 | false, |
2621 | "ProcessGroupNCCL only supports send for NCCL lib version >= 2.7.0" ); |
2622 | } |
2623 | |
2624 | c10::intrusive_ptr<Work> ProcessGroupNCCL::recv( |
2625 | std::vector<at::Tensor>& /* unused */, |
2626 | int /* unused */, |
2627 | int /* unused */) { |
2628 | TORCH_CHECK( |
2629 | false, |
2630 | "ProcessGroupNCCL only supports recv for NCCL lib version >= 2.7.0" ); |
2631 | } |
2632 | #endif |
2633 | |
2634 | void ProcessGroupNCCL::groupStart() { |
2635 | #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) |
2636 | C10D_NCCL_CHECK(ncclGroupStart(), c10::nullopt); |
2637 | #endif |
2638 | ++ncclActiveGroupCounter_; |
2639 | } |
2640 | |
2641 | void ProcessGroupNCCL::groupEnd() { |
2642 | #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) |
2643 | C10D_NCCL_CHECK(ncclGroupEnd(), c10::nullopt); |
2644 | #endif |
2645 | --ncclActiveGroupCounter_; |
2646 | } |
2647 | |
2648 | c10::intrusive_ptr<Work> ProcessGroupNCCL::gather( |
2649 | std::vector<std::vector<at::Tensor>>& outputTensors, |
2650 | std::vector<at::Tensor>& inputTensors, |
2651 | const GatherOptions& opts) { |
2652 | static auto invalidArgument = [](const std::string& msg) { |
2653 | TORCH_CHECK(false, "ProcessGroupNCCL::gather: " + msg); |
2654 | }; |
2655 | |
2656 | assertRootRank(invalidArgument, opts.rootRank, size_); |
2657 | check_gpu_tensors_different_devices(inputTensors); |
2658 | assertSingleElementInput(invalidArgument, inputTensors); |
2659 | |
2660 | // @lint-ignore CLANGTIDY |
2661 | auto tensor = inputTensors.back(); |
2662 | |
2663 | std::vector<at::Tensor> outputs; |
2664 | |
2665 | if (getRank() == opts.rootRank) { |
2666 | if (outputTensors.size() != 1) { |
2667 | std::stringstream ss; |
2668 | ss << "requires a single-element output list containing a list with " |
2669 | << getSize() << " tensors." ; |
2670 | invalidArgument(ss.str()); |
2671 | } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) { |
2672 | std::stringstream ss; |
2673 | ss << "Incorrect output list size " << outputTensors[0].size() |
2674 | << ". Output list size should be " << getSize() |
2675 | << ", same as size of the process group." ; |
2676 | invalidArgument(ss.str()); |
2677 | } |
2678 | |
2679 | const auto& options = inputTensors[0].options(); |
2680 | const auto& sizes = inputTensors[0].sizes(); |
2681 | assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes); |
2682 | outputs = outputTensors[0]; |
2683 | } else { |
2684 | // if not the root rank, initialize outputs as an empty list
2685 | if (outputTensors.size() != 0) { |
2686 | invalidArgument("requires empty output on non-root" ); |
2687 | } |
2688 | outputs = {}; |
2689 | // append an empty tensor to the list; we don't use it, but the
2690 | // `collective` template function requires an entry to invoke its per-tensor function
2691 | outputs.emplace_back(); |
2692 | } |
2693 | |
2694 | RECORD_PARAM_COMMS_DATA( |
2695 | static_cast<int>( |
2696 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
2697 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2698 | inputTensors, // inputTensors |
2699 | outputTensors, // outputTensors |
2700 | rank_, // rank |
2701 | "gather" , // colName |
2702 | tensor.numel(), // inSize |
2703 | tensor.numel() * this->getSize(), // outSize |
2704 | tensor.scalar_type(), // dType |
2705 | std::vector<int64_t>(), // inSplitSizes |
2706 | std::vector<int64_t>()); // outSplitSizes
2707 | |
2708 | return collective( |
2709 | inputTensors, |
2710 | outputs, |
2711 | [&](at::Tensor& /* unused */, |
2712 | at::Tensor& /* unused */, |
2713 | ncclComm_t comm, |
2714 | at::cuda::CUDAStream& stream) { |
2715 | const auto root = opts.rootRank; |
2716 | if (getRank() == root) { |
2717 | for (auto output : outputs) { |
2718 | c10::cuda::CUDACachingAllocator::recordStream( |
2719 | output.storage().data_ptr(), stream); |
2720 | } |
2721 | } |
2722 | torch::cuda::nccl::gather(inputTensors[0], outputs, comm, stream, root); |
2723 | return ncclSuccess; |
2724 | }, |
2725 | OpType::GATHER, |
2726 | "nccl:gather" ); |
2727 | } |
2728 | |
2729 | c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter( |
2730 | std::vector<at::Tensor>& outputTensors, |
2731 | std::vector<std::vector<at::Tensor>>& inputTensors, |
2732 | const ScatterOptions& opts) { |
2733 | static auto invalidArgument = [](const std::string& msg) { |
2734 | TORCH_CHECK(false, "ProcessGroupNCCL::scatter: " + msg); |
2735 | }; |
2736 | |
2737 | assertRootRank(invalidArgument, opts.rootRank, size_); |
2738 | check_gpu_tensors_different_devices(outputTensors); |
2739 | assertSingleElementInput(invalidArgument, outputTensors); |
2740 | |
2741 | // @lint-ignore CLANGTIDY |
2742 | auto tensor = outputTensors.back(); |
2743 | |
2744 | std::vector<at::Tensor> inputs; |
2745 | |
2746 | if (getRank() == opts.rootRank) { |
2747 | if (inputTensors.size() != 1) { |
2748 | std::stringstream ss; |
2749 | ss << "requires a single-element input list containing a list with " |
2750 | << getSize() << " tensors." ; |
2751 | invalidArgument(ss.str()); |
2752 | } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) { |
2753 | std::stringstream ss; |
2754 | ss << "Incorrect input list size " << inputTensors[0].size() |
2755 | << ". Input list size should be " << getSize() |
2756 | << ", same as size of the process group." ; |
2757 | invalidArgument(ss.str()); |
2758 | } |
2759 | |
2760 | const auto& options = outputTensors[0].options(); |
2761 | const auto& sizes = outputTensors[0].sizes(); |
2762 | assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes); |
2763 | inputs = inputTensors[0]; |
2764 | } else { |
2765 | // if not the root rank, initialize the inputs as an empty
2766 | // placeholder list
2767 | if (inputTensors.size() != 0) { |
2768 | invalidArgument("requires empty input on non-root" ); |
2769 | } |
2770 | inputs = {}; |
2771 | // append an empty tensor to the list; we don't use it, but the
2772 | // `collective` template function requires an entry to invoke its per-tensor function
2773 | inputs.emplace_back(); |
2774 | } |
2775 | |
2776 | RECORD_PARAM_COMMS_DATA( |
2777 | static_cast<int>( |
2778 | this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective |
2779 | reinterpret_cast<std::intptr_t>(this), // process group ptr |
2780 | inputTensors, // inputTensors |
2781 | outputTensors, // outputTensors |
2782 | rank_, // rank |
2783 | "scatter" , // colName |
2784 | tensor.numel(), // inSize |
2785 | tensor.numel() * this->getSize(), // outSize |
2786 | tensor.scalar_type(), // dType |
2787 | std::vector<int64_t>(), // inSplitSizes |
2788 | std::vector<int64_t>()); // outSplitSizes
2789 | |
2790 | return collective( |
2791 | outputTensors, |
2792 | inputs, |
2793 | [&](at::Tensor& /* unused */, |
2794 | at::Tensor& /* unused */, |
2795 | ncclComm_t comm, |
2796 | at::cuda::CUDAStream& stream) { |
2797 | const auto root = opts.rootRank; |
2798 | if (getRank() == root) { |
2799 | for (auto input : inputs) { |
2800 | c10::cuda::CUDACachingAllocator::recordStream( |
2801 | input.storage().data_ptr(), stream); |
2802 | } |
2803 | } |
2804 | torch::cuda::nccl::scatter( |
2805 | inputs, outputTensors[0], comm, stream, root); |
2806 | return ncclSuccess; |
2807 | }, |
2808 | OpType::SCATTER, |
2809 | "nccl:scatter" ); |
2810 | } |
2811 | |
2812 | c10::intrusive_ptr<Work> ProcessGroupNCCL::recvAnysource( |
2813 | std::vector<at::Tensor>& /* unused */, |
2814 | int /* unused */) { |
2815 | TORCH_CHECK(false, "ProcessGroupNCCL does not support recvAnysource" ); |
2816 | } |
2817 | |
2818 | c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base( |
2819 | at::Tensor& output_tensor, |
2820 | at::Tensor& input_tensor, |
2821 | const AllgatherOptions& /*unused */) { |
2822 | check_gpu_single_tensor(input_tensor); |
2823 | check_gpu_single_tensor(output_tensor); |
2824 | |
2825 | if (input_tensor.dtype() != output_tensor.dtype()) { |
2826 | TORCH_CHECK(false, "output tensor must have the same type as input tensor" ); |
2827 | } |
2828 | |
2829 | if (input_tensor.numel() * size_ != output_tensor.numel()) { |
2830 | TORCH_CHECK( |
2831 | false, |
2832 | "output tensor size must be equal to world_size times input tensor size" ); |
2833 | } |
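// Size relationship, illustrated with example values only: with size_ = 4
// and an input of 1024 elements on each rank, the output must hold
// 4 * 1024 = 4096 elements; rank r's contribution occupies the r-th
// 1024-element chunk of the output.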
2834 | |
2835 | // just a wrapper to fit the collective interface |
2836 | auto inputs = std::vector<at::Tensor>{input_tensor}; |
2837 | auto outputs = std::vector<at::Tensor>{output_tensor}; |
2838 | |
2839 | return collective( |
2840 | inputs, |
2841 | outputs, |
2842 | [&](at::Tensor& input, |
2843 | at::Tensor& output, |
2844 | ncclComm_t comm, |
2845 | at::cuda::CUDAStream& stream) { |
2846 | c10::cuda::CUDACachingAllocator::recordStream( |
2847 | output.storage().data_ptr(), stream); |
2848 | return ncclAllGather( |
2849 | input.data_ptr(), |
2850 | output.data_ptr(), |
2851 | input.numel(), |
2852 | getNcclDataType(input.scalar_type()), |
2853 | comm, |
2854 | stream.stream()); |
2855 | }, |
2856 | [&](std::vector<at::cuda::CUDAStream>&) {}, |
2857 | [&](std::vector<at::cuda::CUDAStream>&) {}, |
2858 | OpType::_ALLGATHER_BASE, |
2859 | "nccl:_all_gather_base" ); |
2860 | } |
2861 | |
2862 | #ifdef USE_NCCL_WITH_UCC |
2863 | std::shared_ptr<at::DynamicLibrary> ProcessGroupNCCL::uccLib_ = nullptr; |
2864 | #endif |
2865 | |
2866 | bool ProcessGroupNCCL::isUCCAvailable() const { |
2867 | #ifdef USE_NCCL_WITH_UCC |
2868 | return (uccPG_ != nullptr); |
2869 | #else |
2870 | return false; |
2871 | #endif |
2872 | } |
2873 | |
2874 | } // namespace c10d |
2875 | |
2876 | #endif // USE_C10D_NCCL |
2877 | |