1#pragma once
2
3#ifdef USE_C10D_NCCL
4
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <exception>
#include <iostream>
#include <list>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
11
12#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
13#include <torch/csrc/distributed/c10d/Backend.hpp>
14#include <torch/csrc/distributed/c10d/Store.hpp>
15#include <torch/csrc/distributed/c10d/UCCForNCCL.hpp>
16
17#include <ATen/DynamicLibrary.h>
18#include <ATen/cuda/CUDAContext.h>
19#include <ATen/cuda/CUDAEvent.h>
20#include <c10/core/Stream.h>
21#include <c10/core/StreamGuard.h>
22#include <c10/cuda/CUDACachingAllocator.h>
23#include <c10/cuda/CUDAGuard.h>
24#include <c10/cuda/CUDAStream.h>
25
26#include <torch/custom_class.h>
27
28namespace c10d {
// Environment variable which controls whether we perform a NCCL health check
// which ensures communicators are healthy at the beginning of init.
constexpr const char* ENABLE_NCCL_HEALTH_CHECK = "ENABLE_NCCL_HEALTH_CHECK";

// Environment variable which controls whether or not wait() is blocking or
// non-blocking.
constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";

// Environment variable which controls whether or not we perform Async Error
// Handling with NCCL.
constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING";

// Environment variable to control whether Desync Debug is enabled.
// This variable must be set together with NCCL_ASYNC_ERROR_HANDLING.
constexpr const char* NCCL_DESYNC_DEBUG = "NCCL_DESYNC_DEBUG";

// Backend name reported by ProcessGroupNCCL::getBackendName().
constexpr const char* NCCL_BACKEND_NAME = "nccl";
46
// Modes for handling asynchronous NCCL errors:
//   NoHandling:  asynchronous error handling is disabled.
//   TearDown:    tear down the process upon error, see
//                `WorkNCCL::handleNCCLGuard`.
//   CleanUpOnly: just clean up collectives and abort communicators without
//                tearing down the process.
enum ErrorHandlingMode { NoHandling = 0, TearDown = 1, CleanUpOnly = 2 };
50
51// ProcessGroupNCCL implements NCCL bindings for c10d.
52//
53// All functions of the class are expected to be called in the same order
54// across all processes in the process group. This is the only way that we
55// can guarantee to match up the same calls among all processes.
56//
// All NCCL functions provided by this class are asynchronous functions. More
// specifically, each NCCL call is scheduled on a separate CUDA stream that is
// different from the current CUDA stream. This is for the purpose of
// achieving potential concurrency and better performance. As a result,
// it is the caller's responsibility to make sure that the CUDA stream their
// code works on waits for the NCCL operation from this class.
//
// This can be done by calling either WorkNCCL::wait() or
// WorkNCCL::synchronize(); both achieve the same functionality and are
// synonyms.
69//
70// Also note that WorkNCCL::finishedGPUExecution() is a helper function only
71// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
72// finished execution on the GPU (not just scheduled).
73//
// Example of using the NCCL process group:
//
//   ProcessGroupNCCL pg(store, rank, size);
//   c10::intrusive_ptr<Work> work = pg.allreduce(tensors);
//
//   // At this point, the NCCL kernel has already been queued successfully.
//   // Now, let the current stream wait for the NCCL operation to finish.
//   // This function is an asynchronous operation as well.
//
//   work->wait();
//
//   // Now continue on other work in the current stream.
86class TORCH_API ProcessGroupNCCL : public Backend {
87 public:
88 class WorkNCCL : public Work,
89 public std::enable_shared_from_this<WorkNCCL> {
90 public:
91 // Constructor takes a list of CUDA devices
92 WorkNCCL(
93 const std::vector<at::Device>& devices,
94 int rank,
95 OpType opType,
96 uint64_t seq,
97 const char* profilingTitle = nullptr,
98 const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt,
99 bool desyncDebug = false);
100 // Copy constructor doing partial copy without outputs_. Cleanup thread
101 // monitors and removes finished works. However it will deadlock when
102 // destructs outputs_ tensors who are view tensors in autograd graph.
103 WorkNCCL(const WorkNCCL& w);
104
105 ~WorkNCCL() override;
106
107 // Checks if the NCCL kernel has started to execute.
108 bool isStarted();
109
110 // Checks if request has completed. In this specific case of NCCL, it checks
111 // if the NCCL operation has completed on the GPU in its own NCCL stream.
112 // Non-blocking operation.
113 bool isCompleted() override;
114
115 bool isSuccess() const override;
116
117 // Same as calling synchronize() for NCCL work.
118 bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
119
120 void abort() override;
121
122 // Let current stream wait on the completing of the NCCL work
123 // Throws on exceptions. Blocking operation, which will wait for work
124 // completion.
125 void synchronize() override;
126
127 // Synchronize streams by blocking each on the NCCL stream
128 void synchronizeStreams();
129
130 // Helper function used in CUDA Stream callbacks to complete WorkNCCL
131 // objects and throw exceptions when neeeded.
132 void handleNCCLGuard(ErrorHandlingMode asyncErrorHandling);
133
134 // Helper function that checks if the NCCL kernels have finished
135 // execution on the GPUs
136 bool finishedGPUExecution();
137
138 // Get a Future object that will be marked as completed internally.
139 c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
140
141 // Helper function that sets an exception_ptr on the WorkNCCL object.
142 void setException(std::exception_ptr exception_ptr);
143
144 // Helper function that returns True if the WorkNCCL object has timed out
145 // and False otherwise.
146 bool timedOut();
147
148 std::vector<at::Tensor> result() override;
149
150 protected:
151 // The cached list of CUDA devices to operate on
152 std::vector<at::Device> devices_;
153
154 // The start CUDA events of NCCL operator tracking this work item on
155 // multiple CUDA devices. These start CUDA events are needed by desync
156 // debugging if enabled.
157 std::shared_ptr<std::vector<at::cuda::CUDAEvent>> ncclStartEvents_;
158
159 // The end CUDA events of NCCL operator tracking this work item on
160 // multiple CUDA devices.
161 std::shared_ptr<std::vector<at::cuda::CUDAEvent>> ncclEndEvents_;
162
163 // The NCCL communicators used for this work item.
164 std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
165
166 // Tensors used for barrier op
167 std::vector<at::Tensor> barrierTensors_;
168
169 // Clone of blockingWait_ from ProcessGroupNCCL.
170 bool blockingWait_ = false;
171
172 // Clone of opTimeout_ from ProcessGroupNCCL.
173 std::chrono::milliseconds opTimeout_;
174
175 // Time point representing when the work started.
176 std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
177
178 // Record the collective sequential number.
179 uint64_t seq_;
180
181 // Indicates if the nccl start event has been updated to the store trace.
182 // This will be used by desync debug.
183 bool startTraceUpdated_{false};
184
185 // Wrapper method for the static checkForNCCLErrors which can be overridden
186 // for tests.
187 virtual std::exception_ptr checkForNCCLErrors(
188 const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;
189
190 friend std::ostream& operator<<(
191 std::ostream& output,
192 const WorkNCCL& workNCCL);
193
194 private:
195 // Helper function for synchronize
196 void synchronizeInternal(std::chrono::milliseconds timeout);
197 // Checks for NCCL errors and sets an appropriate exception_ptr.
198 void checkAndSetException();
199
200 // Checks for NCCL errors and throws an appropriate exception.
201 void checkAndThrowException();
202
203 // Just checks whether GPU execution has started, without modifying
204 // exception_ptr.
205 bool startedGPUExecutionInternal() const;
206
207 // Just checks whether GPU execution has completed, without modifying
208 // exception_ptr.
209 bool finishedGPUExecutionInternal() const;
210
211 // Reference to the store so that we can write aborted communicators
212 // to the store.
213 c10::intrusive_ptr<Store> store_;
214
215 // Store a reference to NCCL collective's outputs, used by result and to
216 // give a more descriptive message when representing the Work as a string.
217 std::shared_ptr<std::vector<at::Tensor>> outputs_;
218
219 // The future returned by getFuture.
220 c10::intrusive_ptr<at::ivalue::Future> future_;
221
222 friend class ProcessGroupNCCL;
223 };
224
225 class CoalescedWorkNCCL
226 : public Work,
227 public std::enable_shared_from_this<CoalescedWorkNCCL> {
228 public:
229 // Constructor takes a list of WorkNCCL works
230 CoalescedWorkNCCL(
231 std::vector<ProcessGroupNCCL::WorkNCCL> works,
232 int rank,
233 OpType opType);
234
235 ~CoalescedWorkNCCL() override;
236
237 // Same as calling synchronize() for NCCL work.
238 bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
239
240 protected:
241 // The cached list of CUDA devices to operate on
242 std::vector<ProcessGroupNCCL::WorkNCCL> works_;
243
244 friend class ProcessGroupNCCL;
245 };
246
247 struct Options : Backend::Options {
248 // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for
249 // operations. This is only used when blockingWait_ is enabled.
250 explicit Options(
251 bool is_high_priority_stream = false);
252
253 // return intrusive_ptr of the object
254 static c10::intrusive_ptr<Options> create(
255 bool is_high_priority_stream = false) {
256 return c10::make_intrusive<Options>(is_high_priority_stream);
257 }
258
259 // Schedule NCCL operations on high priority CUDA streams
260 bool is_high_priority_stream;
261 };
262
263 // If you wish to create multiple process groups, each with a potentially
264 // different rank and size, you can do so by passing a new store instance
265 // to each one. If you have only a single store object, you can
266 // use the `c10d::PrefixStore` to derive scoped instances.
267 // This is also what the Python API in torch.distributed does.
268 //
269 // The process group instance keeps a reference to the store because
270 // it may be used long after the constructor runs. In fact, the constructor
271 // doesn't create any NCCL communicators. A single NCCL communicator can
272 // only be used on a specific set of devices, and are therefore created
273 // on-demand when a collective runs. If another collective is executed later,
274 // against a different set of devices, the process group creates another NCCL
275 // communicator. These NCCL communicators are cached and reused if possible.
276 //
277 ProcessGroupNCCL(
278 const c10::intrusive_ptr<Store>& store,
279 int rank,
280 int size,
281 c10::intrusive_ptr<Options> options = Options::create());
282
283 // This constructor includes the deprecated `groupName` argument.
284 // If you have existing code that uses the `groupName`, you can replace
285 // it by specifying a `c10d::PrefixStore(groupName, store)` for store.
286 C10_DEPRECATED ProcessGroupNCCL(
287 const c10::intrusive_ptr<Store>& store,
288 int rank,
289 int size,
290 const std::string& groupName,
291 c10::intrusive_ptr<Options> options = Options::create())
292 : ProcessGroupNCCL(store, rank, size, options) {}
293
294 ~ProcessGroupNCCL() override;
295
296 c10::intrusive_ptr<Options> getOptions() {
297 return options_;
298 }
299
300 const std::string getBackendName() const override {
301 return std::string(NCCL_BACKEND_NAME);
302 }
303
304 void startCoalescing() override;
305
306 void endCoalescing(
307 std::vector<c10::intrusive_ptr<Work>>& reqs) override;
308
309 c10::intrusive_ptr<Work> broadcast(
310 std::vector<at::Tensor>& tensors,
311 const BroadcastOptions& opts = BroadcastOptions()) override;
312
313 c10::intrusive_ptr<Work> _broadcast_oop(
314 std::vector<at::Tensor>& outputTensors,
315 std::vector<at::Tensor>& inputTensors,
316 const BroadcastOptions& opts = BroadcastOptions());
317
318 c10::intrusive_ptr<Work> allreduce(
319 std::vector<at::Tensor>& tensors,
320 const AllreduceOptions& opts = AllreduceOptions()) override;
321
322 c10::intrusive_ptr<Work> allreduce_coalesced(
323 std::vector<at::Tensor>& tensors,
324 const AllreduceCoalescedOptions& opts =
325 AllreduceCoalescedOptions()) override;
326
327 c10::intrusive_ptr<Work> reduce(
328 std::vector<at::Tensor>& tensors,
329 const ReduceOptions& opts = ReduceOptions()) override;
330
331 c10::intrusive_ptr<Work> _reduce_oop(
332 std::vector<at::Tensor>& outputTensors,
333 std::vector<at::Tensor>& inputTensors,
334 const ReduceOptions& opts = ReduceOptions());
335
336 c10::intrusive_ptr<Work> allgather(
337 std::vector<std::vector<at::Tensor>>& outputTensors,
338 std::vector<at::Tensor>& inputTensors,
339 const AllgatherOptions& opts = AllgatherOptions()) override;
340
341 c10::intrusive_ptr<Work> _allgather_base(
342 at::Tensor& outputbuffer,
343 at::Tensor& inputbuffer,
344 const AllgatherOptions& opts = AllgatherOptions()) override;
345
346 c10::intrusive_ptr<Work> allgather_coalesced(
347 std::vector<std::vector<at::Tensor>>& outputTensorLists,
348 std::vector<at::Tensor>& inputTensors,
349 const AllgatherOptions& opts = AllgatherOptions()) override;
350
351 c10::intrusive_ptr<Work> reduce_scatter(
352 std::vector<at::Tensor>& outputTensors,
353 std::vector<std::vector<at::Tensor>>& inputTensors,
354 const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
355
356 c10::intrusive_ptr<Work> _reduce_scatter_base(
357 at::Tensor& outputTensor,
358 at::Tensor& inputTensor,
359 const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
360
361 c10::intrusive_ptr<Work> barrier(
362 const BarrierOptions& opts = BarrierOptions()) override;
363
364 c10::intrusive_ptr<Work> alltoall_base(
365 at::Tensor& outputTensor,
366 at::Tensor& inputTensor,
367 std::vector<int64_t>& outputSplitSizes,
368 std::vector<int64_t>& inputSplitSizes,
369 const AllToAllOptions& opts = AllToAllOptions()) override;
370
371 c10::intrusive_ptr<Work> alltoall(
372 std::vector<at::Tensor>& outputTensors,
373 std::vector<at::Tensor>& inputTensors,
374 const AllToAllOptions& opts = AllToAllOptions()) override;
375
376 c10::intrusive_ptr<Work> send(
377 std::vector<at::Tensor>& tensors,
378 int dstRank,
379 int tag) override;
380
381 c10::intrusive_ptr<Work> recv(
382 std::vector<at::Tensor>& tensors,
383 int srcRank,
384 int tag) override;
385
386 static void groupStart();
387
388 static void groupEnd();
389
390 // Unsupported Ops
391 c10::intrusive_ptr<Work> gather(
392 std::vector<std::vector<at::Tensor>>& outputTensors,
393 std::vector<at::Tensor>& inputTensors,
394 const GatherOptions& opts = GatherOptions()) override;
395
396 c10::intrusive_ptr<Work> scatter(
397 std::vector<at::Tensor>& outputTensors,
398 std::vector<std::vector<at::Tensor>>& inputTensors,
399 const ScatterOptions& opts = ScatterOptions()) override;
400
401 c10::intrusive_ptr<Work> recvAnysource(
402 std::vector<at::Tensor>& tensors,
403 int tag) override;
404
405 // Agrees on an initial sequence number for the whole group by having rank 0
406 // create it and broadcast it to other ranks using the store.
407 void setSequenceNumberForGroup() override;
408
409 // Retrieves the current sequence number for the whole group, which should be
410 // in sync. If the returned number is not consistent across the group, it
411 // may indicate that there is some sort of collective desynchronization.
412 uint64_t getSequenceNumberForGroup() override;
413
414 // Tests if the UCC fallback path is available
415 bool isUCCAvailable() const;
416
417 protected:
418 // Helper that broadcasts nccl unique ID to all ranks through the store
419 void broadcastUniqueNCCLID(
420 ncclUniqueId* ncclID,
421 bool isSingleP2POp,
422 const std::string& devicesKey,
423 int p2pRank);
424
425 // Helper that either looks up the cached NCCL communicators or creates
426 // a new set of NCCL communicators as a cache entry
427 std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
428 const std::string& devicesKey,
429 const std::vector<at::Device>& devices,
430 OpType opType,
431 int p2pRank = 0,
432 bool isSendRecvSelf = false);
433
434 // Wrapper method which can be overridden for tests.
435 virtual std::exception_ptr checkForNCCLErrors(
436 const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
437
438 virtual c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
439 std::vector<at::Device> devices,
440 int rank,
441 OpType opType,
442 const char* profilingTitle=nullptr,
443 const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt);
444
445 virtual c10::intrusive_ptr<ProcessGroupNCCL::CoalescedWorkNCCL>
446 initCoalescedWork(
447 const std::vector<c10::intrusive_ptr<Work>>& works,
448 int rank,
449 OpType opType);
450
451 private:
452 // Helper that encapsulates work shared across all collective communication
453 // primitives. The callbacks have the following signatures:
454 //
455 // ncclResult_t fn(at::Tensor& input, at::Tensor& output,
456 // ncclComm_t, at::cuda::CUDAStream&);
457 // void {pre,post}(std::vector<at::cuda::CUDAStream&>);
458 template <typename Fn>
459 c10::intrusive_ptr<Work> collective(
460 std::vector<at::Tensor>& input,
461 std::vector<at::Tensor>& output,
462 Fn fn,
463 OpType opType,
464 const char* profilingTitle = nullptr);
465 template <typename Fn, typename PreProcess, typename PostProcess>
466 c10::intrusive_ptr<Work> collective(
467 std::vector<at::Tensor>& input,
468 std::vector<at::Tensor>& output,
469 Fn fn,
470 PreProcess pre,
471 PostProcess post,
472 OpType opType,
473 const char* profilingTitle = nullptr);
474
475 // Helper that encapsulates work shared across point-to-point communication
476 // primitives. It is the same structure as the helper used for collective
477 // communicaiton primitives.
478 template <typename Fn>
479 c10::intrusive_ptr<Work> pointToPoint(
480 std::vector<at::Tensor>& tensor,
481 Fn fn,
482 int peer,
483 OpType opType,
484 const char* profilingTitle = nullptr);
485 template <typename Fn, typename PreProcess, typename PostProcess>
486 c10::intrusive_ptr<Work> pointToPoint(
487 std::vector<at::Tensor>& tensor,
488 Fn fn,
489 int peer,
490 OpType opType,
491 PreProcess pre,
492 PostProcess post,
493 const char* profilingTitle);
494
495 c10::intrusive_ptr<Work> allreduce_impl(
496 std::vector<at::Tensor>& tensors,
497 const AllreduceOptions& opts = AllreduceOptions());
498
499 // Checks for NCCL errors on each of the communicators and returns an
500 // appropriate exception_ptr (nullptr if no errors).
501 static std::exception_ptr checkForNCCLErrorsInternal(
502 const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
503
504 // Function that runs as part of a separate thread and checks for errors on
505 // NCCL communicators. We need a separate thread to check for NCCL errors
506 // since we can't rely on the user calling certain methods like wait(),
507 // isCompleted() etc. to detect and remediate errors. In addition to this, we
508 // need a mechanism to safely abort and remove NCCL communicators from our
509 // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
510 // class. Attempting to modify the communicator cache from the WorkNCCL class
511 // might run into issues with object lifetime since the ProcessGroupNCCL
512 // object might get destroyed before the WorkNCCL object.
513 void ncclCommWatchdog();
514
515 void ncclCommWatchdogInternal();
516
517 // This function iterates through the list of WorkNCCL objects in the
518 // workList_ corresponding to incomplete collectives and then aborts NCCL
519 // communicators associated with timed out collectives.
520 void abortTimedOutCollectives(
521 std::unordered_set<std::string>& abortedCommIds);
522
523 // Performs a health check by initializing dummy NCCL communicators and then
524 // destroying them. This will help indicate and signal any NCCL-related issues
525 // prior to the first collective. The actual initialization and subsequent
526 // destruction is ran on a separate thread and the main thread is signalled
527 // about timeouts/errors to report to the application.
528 void runHealthCheck();
529
530 // Destroys initialized NCCL communicators in devNCCLComMap_ given by input
531 // key. Throws if there are no communicators to destroy. Also removes
532 // communicators from the cache and clears used device indices.
533 void destroyNCCLComms(const std::string& devNCCLCommMapKey);
534
535 void workCleanupLoop();
536
537 protected:
538 static const int64_t kWatchdogThreadSleepMillis;
539 static const int64_t kWorkCleanupThreadSleepMillis;
540
541 // The store is used to broadcast the NCCL unique ID of rank 0.
542 c10::intrusive_ptr<Store> store_;
543
544 bool storeError_{false};
545
546 const c10::intrusive_ptr<Options> options_;
547
548 // The number of NCCL communicators that have been created during
549 // the lifetime of this process group. This sequence number is
550 // used to scope keys used in the store.
551 uint64_t ncclCommCounter_{0};
552
553 // The store keys to trace the last NCCL collective kernel CUDA events - start
554 // event and end event respectively. These are used to do desync root cause
555 // analysis.
556 const std::string traceKeyStart_;
557 const std::string traceKeyEnd_;
558
559 // The NCCL communicator that the process group has cached.
560 //
561 // For collective operations:
562 // The key is a list of GPU devices that an operation is operating on
563 // The GPU devices are stored in a device sequence and the cache NCCL
564 // communicator is associated with this GPU device sequence
565 //
566 // e.g. If the process group op only uses device 0, then the value of
567 // the used device string stored (value of the hashmap) would be "0".
568 //
569 // If the process group op uses device 0 - 7 and the each tensor of the
570 // input tensor list is on device, 0, 1, 2, 3, 4, 5, 6, 7 separately,
571 // then the value of the used device string (key) stored would be
572 // "0,1,2,3,4,5,6,7"
573 //
574 // If the process group op uses device 0 - 7 and the each tensor of the
575 // input tensor list is on device, 0, 4, 5, 6, 7, 1, 2, 3 separately,
576 // then the value of the used device string stored would be
577 // "0,4,5,6,7,1,2,3"
578 //
579 // Note that the order of the device for the tensor list matters.
580 //
581 // For point-to-point operations:
582 // The key is a string of my current rank and the peer process rank.
583 // e.g. If process 1 and process 2 are involved in a point-to-point
584 // communication, the key will be "1:2" on both processes. Note: this is for
585 // the scenario where there is only 1 GPU per process. When it comes to
586 // multiple GPUs per process, this part may need to redesigned.
587 std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
588 devNCCLCommMap_;
589
590 // Map from ncclUniqueId to appropriate communicator.
591 std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
592 ncclIdToCommMap_;
593
594 // Mutex to guard maps like devNCCLCommMap_ and ncclIdToCommMap_.
595 std::mutex mutex_;
596
597 // Watchdog thread which looks for errors on the cached NCCL communicators.
598 std::thread ncclCommWatchdogThread_;
599
600 // Whether or not we should terminate the watchdog and workCleanup threads.
601 std::atomic<bool> terminateProcessGroup_;
602
603 // Condition variable to control how long the watchdog thread waits.
604 std::condition_variable watchdogCV_;
605
606 // Mutex for watchdog.
607 std::mutex watchdogCVMutex_;
608
609 // Thread that removes NCCL Work upon timeout
610 std::thread workCleanupThread_;
611
612 // Mutex to Guard workMetaList_
613 std::mutex workMetaListMutex_;
614
615 // Condition Variable for timeout thread sleep
616 std::condition_variable workMetaListCV_;
617
618 // Vector to Store WorkNCCL pointers
619 std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
620
621 // Add Work Pointer to workVector
622 void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>);
623
624 // The CUDA steams used by NCCL kernels
625 std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
626 ncclStreams_;
627
628 // The CUDA events used to sync NCCL streams
629 std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;
630
631 // Device Indexes used for all collectives in this group
632 std::set<int> usedDeviceIdxs_;
633
634 // Flag to denote if a coalescing groupStart/groupEnd block is active
635 bool coalescing_active_ = false;
636
637 // Stores device indexes for all collectives run inside a coalescing block
638 std::vector<std::vector<at::Device>> coalescedDevices_;
639
640 // map from the key: "group name + pg counter (ID)" to the
641 // unique NCCL ID count. This needs to be group and pg specific
642 //
643 // For each process group, we need a uniform unique NCCL ID counter to ensure
644 // that NCCL operation in this process group can be completed successfully.
645 // Since each process group ID belongs to a group name, the key to this map
646 // is a combination of group name and ProcessGroupNCCL ID.
647 static std::unordered_map<std::string, ssize_t> pgUniqueNCCLIDCnt_;
648
649 // map from group name to the pg counter (ID) within that group
650 //
651 // For each group with the "group name" (which is the key), we need to
652 // keep track of a unique process group ID when creating a new
653 // ProcessGroupNCCL for this "group name". Therefore, the value of this
654 // map keeps the unique ProcessGroupNCCL's ID for a specific group with
655 // the "group name". The reason we need a per-group process group ID counter
656 // is that different group can have different ranks and we need ensure that
657 // each group has its own uniform process group ID for all its ranks.
658 static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;
659
660 // Whether or not wait() and synchronize() are blocking operations that wait
661 // for the operation to complete.
662 bool blockingWait_ = false;
663
664 // Whether or not the workCleanupThread is used to perform async error
665 // handling.
666 ErrorHandlingMode asyncErrorHandling_ = NoHandling;
667
668 // Whether or not to enable timeout root cause analysis.
669 bool desyncDebug_;
670
671 // Set of communicators that this process group has aborted and their
672 // ncclUniqueId has been written to the store. We don't need a lock
673 // for this map since only the watchdog thread accesses this set. The
674 // set contains the string representation of ncclUniqueId.
675 std::unordered_set<std::string> abortedComms_;
676
677 // The number of active ncclGroupStart() calls. This counter will be increased
678 // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd()
679 // is called.
680 static thread_local uint64_t ncclActiveGroupCounter_;
681
682 // Counting for the sequential number of NCCL collective call.
683 uint64_t seq_{0};
684
685#ifdef USE_NCCL_WITH_UCC
686 // ProcessGroupUCC shared library handle and ProcessGroup pointer
687 static std::shared_ptr<at::DynamicLibrary> uccLib_;
688 c10::intrusive_ptr<Backend> uccPG_;
689#endif
690};
691
692} // namespace c10d
693
694#endif // USE_C10D_NCCL
695