1 | #pragma once |
2 | |
3 | #ifdef USE_C10D_NCCL |
4 | |
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <exception>
#include <iostream>
#include <list>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/UCCForNCCL.hpp>

#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

#include <torch/custom_class.h>
27 | |
28 | namespace c10d { |
// Environment variable which controls whether we perform a NCCL health check
// which ensures communicators are healthy at the beginning of init.
constexpr const char* ENABLE_NCCL_HEALTH_CHECK = "ENABLE_NCCL_HEALTH_CHECK";

// Environment variable which controls whether or not wait() is blocking or
// non-blocking.
constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";

// Environment variable which controls whether or not we perform Async Error
// Handling with NCCL.
constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING";

// Environment variable which controls whether Desync Debug is enabled.
// This variable must be set together with NCCL_ASYNC_ERROR_HANDLING.
constexpr const char* NCCL_DESYNC_DEBUG = "NCCL_DESYNC_DEBUG";

// Backend name reported by ProcessGroupNCCL::getBackendName().
constexpr const char* NCCL_BACKEND_NAME = "nccl";
46 | |
// How ProcessGroupNCCL reacts to asynchronous NCCL errors (presumably chosen
// from the NCCL_ASYNC_ERROR_HANDLING environment variable — confirm in the
// constructor):
//   NoHandling  - no asynchronous error handling (the default for
//                 ProcessGroupNCCL::asyncErrorHandling_).
//   TearDown    - tear down the process upon error; see
//                 `WorkNCCL::handleNCCLGuard`.
//   CleanUpOnly - "soft" mode: just clean up collectives and abort
//                 communicators without tearing down the process.
// Kept as an unscoped enum because the enumerators are referenced
// unqualified (e.g. `NoHandling`) by existing code.
enum ErrorHandlingMode { NoHandling = 0, TearDown = 1, CleanUpOnly = 2 };
50 | |
51 | // ProcessGroupNCCL implements NCCL bindings for c10d. |
52 | // |
53 | // All functions of the class are expected to be called in the same order |
54 | // across all processes in the process group. This is the only way that we |
55 | // can guarantee to match up the same calls among all processes. |
56 | // |
57 | // All NCCL functions provided by this class are asynchronous functions. More |
58 | // specifically, each NCCL call is scheduled on a separate CUDA stream that is |
59 | // different from the current CUDA stream. This is for the purpose of |
60 | // achieving potentially concurrency and better performance. As a result, |
61 | // it is the callers' responsibility to make sure that the CUDA stream their |
62 | // code works on needs to wait for the NCCL operation from |
63 | // this class. |
64 | // |
65 | // This can be done by calling: |
66 | // |
67 | // either WorkNCCL::wait() or WorkNCCL::synchronize(), both achieves the same |
68 | // functionality and are synonyms. |
69 | // |
70 | // Also note that WorkNCCL::finishedGPUExecution() is a helper function only |
71 | // provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has |
72 | // finished execution on the GPU (not just scheduled). |
73 | // |
74 | // Example on using the NCCL process group |
75 | // |
76 | // ProcessGroupNCCL pg(store, rank, size); |
77 | // std::shared_ptr<WorkNCCL> work = pg.allreduce(tensors); |
78 | // |
79 | // // At this point, NCCL kernel has already by queued successfully |
80 | // // Now, let current stream wait for the NCCL to finish, this function is |
81 | // // async operation as well |
82 | // |
83 | // work->wait() |
84 | // |
85 | // // Now continue on other work in the current stream. |
86 | class TORCH_API ProcessGroupNCCL : public Backend { |
87 | public: |
88 | class WorkNCCL : public Work, |
89 | public std::enable_shared_from_this<WorkNCCL> { |
90 | public: |
91 | // Constructor takes a list of CUDA devices |
92 | WorkNCCL( |
93 | const std::vector<at::Device>& devices, |
94 | int rank, |
95 | OpType opType, |
96 | uint64_t seq, |
97 | const char* profilingTitle = nullptr, |
98 | const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt, |
99 | bool desyncDebug = false); |
100 | // Copy constructor doing partial copy without outputs_. Cleanup thread |
101 | // monitors and removes finished works. However it will deadlock when |
102 | // destructs outputs_ tensors who are view tensors in autograd graph. |
103 | WorkNCCL(const WorkNCCL& w); |
104 | |
105 | ~WorkNCCL() override; |
106 | |
107 | // Checks if the NCCL kernel has started to execute. |
108 | bool isStarted(); |
109 | |
110 | // Checks if request has completed. In this specific case of NCCL, it checks |
111 | // if the NCCL operation has completed on the GPU in its own NCCL stream. |
112 | // Non-blocking operation. |
113 | bool isCompleted() override; |
114 | |
115 | bool isSuccess() const override; |
116 | |
117 | // Same as calling synchronize() for NCCL work. |
118 | bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; |
119 | |
120 | void abort() override; |
121 | |
122 | // Let current stream wait on the completing of the NCCL work |
123 | // Throws on exceptions. Blocking operation, which will wait for work |
124 | // completion. |
125 | void synchronize() override; |
126 | |
127 | // Synchronize streams by blocking each on the NCCL stream |
128 | void synchronizeStreams(); |
129 | |
130 | // Helper function used in CUDA Stream callbacks to complete WorkNCCL |
131 | // objects and throw exceptions when neeeded. |
132 | void handleNCCLGuard(ErrorHandlingMode asyncErrorHandling); |
133 | |
134 | // Helper function that checks if the NCCL kernels have finished |
135 | // execution on the GPUs |
136 | bool finishedGPUExecution(); |
137 | |
138 | // Get a Future object that will be marked as completed internally. |
139 | c10::intrusive_ptr<c10::ivalue::Future> getFuture() override; |
140 | |
141 | // Helper function that sets an exception_ptr on the WorkNCCL object. |
142 | void setException(std::exception_ptr exception_ptr); |
143 | |
144 | // Helper function that returns True if the WorkNCCL object has timed out |
145 | // and False otherwise. |
146 | bool timedOut(); |
147 | |
148 | std::vector<at::Tensor> result() override; |
149 | |
150 | protected: |
151 | // The cached list of CUDA devices to operate on |
152 | std::vector<at::Device> devices_; |
153 | |
154 | // The start CUDA events of NCCL operator tracking this work item on |
155 | // multiple CUDA devices. These start CUDA events are needed by desync |
156 | // debugging if enabled. |
157 | std::shared_ptr<std::vector<at::cuda::CUDAEvent>> ncclStartEvents_; |
158 | |
159 | // The end CUDA events of NCCL operator tracking this work item on |
160 | // multiple CUDA devices. |
161 | std::shared_ptr<std::vector<at::cuda::CUDAEvent>> ncclEndEvents_; |
162 | |
163 | // The NCCL communicators used for this work item. |
164 | std::vector<std::shared_ptr<NCCLComm>> ncclComms_; |
165 | |
166 | // Tensors used for barrier op |
167 | std::vector<at::Tensor> barrierTensors_; |
168 | |
169 | // Clone of blockingWait_ from ProcessGroupNCCL. |
170 | bool blockingWait_ = false; |
171 | |
172 | // Clone of opTimeout_ from ProcessGroupNCCL. |
173 | std::chrono::milliseconds opTimeout_; |
174 | |
175 | // Time point representing when the work started. |
176 | std::chrono::time_point<std::chrono::steady_clock> workStartTime_; |
177 | |
178 | // Record the collective sequential number. |
179 | uint64_t seq_; |
180 | |
181 | // Indicates if the nccl start event has been updated to the store trace. |
182 | // This will be used by desync debug. |
183 | bool startTraceUpdated_{false}; |
184 | |
185 | // Wrapper method for the static checkForNCCLErrors which can be overridden |
186 | // for tests. |
187 | virtual std::exception_ptr checkForNCCLErrors( |
188 | const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const; |
189 | |
190 | friend std::ostream& operator<<( |
191 | std::ostream& output, |
192 | const WorkNCCL& workNCCL); |
193 | |
194 | private: |
195 | // Helper function for synchronize |
196 | void synchronizeInternal(std::chrono::milliseconds timeout); |
197 | // Checks for NCCL errors and sets an appropriate exception_ptr. |
198 | void checkAndSetException(); |
199 | |
200 | // Checks for NCCL errors and throws an appropriate exception. |
201 | void checkAndThrowException(); |
202 | |
203 | // Just checks whether GPU execution has started, without modifying |
204 | // exception_ptr. |
205 | bool startedGPUExecutionInternal() const; |
206 | |
207 | // Just checks whether GPU execution has completed, without modifying |
208 | // exception_ptr. |
209 | bool finishedGPUExecutionInternal() const; |
210 | |
211 | // Reference to the store so that we can write aborted communicators |
212 | // to the store. |
213 | c10::intrusive_ptr<Store> store_; |
214 | |
215 | // Store a reference to NCCL collective's outputs, used by result and to |
216 | // give a more descriptive message when representing the Work as a string. |
217 | std::shared_ptr<std::vector<at::Tensor>> outputs_; |
218 | |
219 | // The future returned by getFuture. |
220 | c10::intrusive_ptr<at::ivalue::Future> future_; |
221 | |
222 | friend class ProcessGroupNCCL; |
223 | }; |
224 | |
225 | class CoalescedWorkNCCL |
226 | : public Work, |
227 | public std::enable_shared_from_this<CoalescedWorkNCCL> { |
228 | public: |
229 | // Constructor takes a list of WorkNCCL works |
230 | CoalescedWorkNCCL( |
231 | std::vector<ProcessGroupNCCL::WorkNCCL> works, |
232 | int rank, |
233 | OpType opType); |
234 | |
235 | ~CoalescedWorkNCCL() override; |
236 | |
237 | // Same as calling synchronize() for NCCL work. |
238 | bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; |
239 | |
240 | protected: |
241 | // The cached list of CUDA devices to operate on |
242 | std::vector<ProcessGroupNCCL::WorkNCCL> works_; |
243 | |
244 | friend class ProcessGroupNCCL; |
245 | }; |
246 | |
247 | struct Options : Backend::Options { |
248 | // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for |
249 | // operations. This is only used when blockingWait_ is enabled. |
250 | explicit Options( |
251 | bool is_high_priority_stream = false); |
252 | |
253 | // return intrusive_ptr of the object |
254 | static c10::intrusive_ptr<Options> create( |
255 | bool is_high_priority_stream = false) { |
256 | return c10::make_intrusive<Options>(is_high_priority_stream); |
257 | } |
258 | |
259 | // Schedule NCCL operations on high priority CUDA streams |
260 | bool is_high_priority_stream; |
261 | }; |
262 | |
263 | // If you wish to create multiple process groups, each with a potentially |
264 | // different rank and size, you can do so by passing a new store instance |
265 | // to each one. If you have only a single store object, you can |
266 | // use the `c10d::PrefixStore` to derive scoped instances. |
267 | // This is also what the Python API in torch.distributed does. |
268 | // |
269 | // The process group instance keeps a reference to the store because |
270 | // it may be used long after the constructor runs. In fact, the constructor |
271 | // doesn't create any NCCL communicators. A single NCCL communicator can |
272 | // only be used on a specific set of devices, and are therefore created |
273 | // on-demand when a collective runs. If another collective is executed later, |
274 | // against a different set of devices, the process group creates another NCCL |
275 | // communicator. These NCCL communicators are cached and reused if possible. |
276 | // |
277 | ProcessGroupNCCL( |
278 | const c10::intrusive_ptr<Store>& store, |
279 | int rank, |
280 | int size, |
281 | c10::intrusive_ptr<Options> options = Options::create()); |
282 | |
283 | // This constructor includes the deprecated `groupName` argument. |
284 | // If you have existing code that uses the `groupName`, you can replace |
285 | // it by specifying a `c10d::PrefixStore(groupName, store)` for store. |
286 | C10_DEPRECATED ProcessGroupNCCL( |
287 | const c10::intrusive_ptr<Store>& store, |
288 | int rank, |
289 | int size, |
290 | const std::string& groupName, |
291 | c10::intrusive_ptr<Options> options = Options::create()) |
292 | : ProcessGroupNCCL(store, rank, size, options) {} |
293 | |
294 | ~ProcessGroupNCCL() override; |
295 | |
296 | c10::intrusive_ptr<Options> getOptions() { |
297 | return options_; |
298 | } |
299 | |
300 | const std::string getBackendName() const override { |
301 | return std::string(NCCL_BACKEND_NAME); |
302 | } |
303 | |
304 | void startCoalescing() override; |
305 | |
306 | void endCoalescing( |
307 | std::vector<c10::intrusive_ptr<Work>>& reqs) override; |
308 | |
309 | c10::intrusive_ptr<Work> broadcast( |
310 | std::vector<at::Tensor>& tensors, |
311 | const BroadcastOptions& opts = BroadcastOptions()) override; |
312 | |
313 | c10::intrusive_ptr<Work> _broadcast_oop( |
314 | std::vector<at::Tensor>& outputTensors, |
315 | std::vector<at::Tensor>& inputTensors, |
316 | const BroadcastOptions& opts = BroadcastOptions()); |
317 | |
318 | c10::intrusive_ptr<Work> allreduce( |
319 | std::vector<at::Tensor>& tensors, |
320 | const AllreduceOptions& opts = AllreduceOptions()) override; |
321 | |
322 | c10::intrusive_ptr<Work> allreduce_coalesced( |
323 | std::vector<at::Tensor>& tensors, |
324 | const AllreduceCoalescedOptions& opts = |
325 | AllreduceCoalescedOptions()) override; |
326 | |
327 | c10::intrusive_ptr<Work> reduce( |
328 | std::vector<at::Tensor>& tensors, |
329 | const ReduceOptions& opts = ReduceOptions()) override; |
330 | |
331 | c10::intrusive_ptr<Work> _reduce_oop( |
332 | std::vector<at::Tensor>& outputTensors, |
333 | std::vector<at::Tensor>& inputTensors, |
334 | const ReduceOptions& opts = ReduceOptions()); |
335 | |
336 | c10::intrusive_ptr<Work> allgather( |
337 | std::vector<std::vector<at::Tensor>>& outputTensors, |
338 | std::vector<at::Tensor>& inputTensors, |
339 | const AllgatherOptions& opts = AllgatherOptions()) override; |
340 | |
341 | c10::intrusive_ptr<Work> _allgather_base( |
342 | at::Tensor& outputbuffer, |
343 | at::Tensor& inputbuffer, |
344 | const AllgatherOptions& opts = AllgatherOptions()) override; |
345 | |
346 | c10::intrusive_ptr<Work> allgather_coalesced( |
347 | std::vector<std::vector<at::Tensor>>& outputTensorLists, |
348 | std::vector<at::Tensor>& inputTensors, |
349 | const AllgatherOptions& opts = AllgatherOptions()) override; |
350 | |
351 | c10::intrusive_ptr<Work> reduce_scatter( |
352 | std::vector<at::Tensor>& outputTensors, |
353 | std::vector<std::vector<at::Tensor>>& inputTensors, |
354 | const ReduceScatterOptions& opts = ReduceScatterOptions()) override; |
355 | |
356 | c10::intrusive_ptr<Work> _reduce_scatter_base( |
357 | at::Tensor& outputTensor, |
358 | at::Tensor& inputTensor, |
359 | const ReduceScatterOptions& opts = ReduceScatterOptions()) override; |
360 | |
361 | c10::intrusive_ptr<Work> barrier( |
362 | const BarrierOptions& opts = BarrierOptions()) override; |
363 | |
364 | c10::intrusive_ptr<Work> alltoall_base( |
365 | at::Tensor& outputTensor, |
366 | at::Tensor& inputTensor, |
367 | std::vector<int64_t>& outputSplitSizes, |
368 | std::vector<int64_t>& inputSplitSizes, |
369 | const AllToAllOptions& opts = AllToAllOptions()) override; |
370 | |
371 | c10::intrusive_ptr<Work> alltoall( |
372 | std::vector<at::Tensor>& outputTensors, |
373 | std::vector<at::Tensor>& inputTensors, |
374 | const AllToAllOptions& opts = AllToAllOptions()) override; |
375 | |
376 | c10::intrusive_ptr<Work> send( |
377 | std::vector<at::Tensor>& tensors, |
378 | int dstRank, |
379 | int tag) override; |
380 | |
381 | c10::intrusive_ptr<Work> recv( |
382 | std::vector<at::Tensor>& tensors, |
383 | int srcRank, |
384 | int tag) override; |
385 | |
386 | static void groupStart(); |
387 | |
388 | static void groupEnd(); |
389 | |
390 | // Unsupported Ops |
391 | c10::intrusive_ptr<Work> gather( |
392 | std::vector<std::vector<at::Tensor>>& outputTensors, |
393 | std::vector<at::Tensor>& inputTensors, |
394 | const GatherOptions& opts = GatherOptions()) override; |
395 | |
396 | c10::intrusive_ptr<Work> scatter( |
397 | std::vector<at::Tensor>& outputTensors, |
398 | std::vector<std::vector<at::Tensor>>& inputTensors, |
399 | const ScatterOptions& opts = ScatterOptions()) override; |
400 | |
401 | c10::intrusive_ptr<Work> recvAnysource( |
402 | std::vector<at::Tensor>& tensors, |
403 | int tag) override; |
404 | |
405 | // Agrees on an initial sequence number for the whole group by having rank 0 |
406 | // create it and broadcast it to other ranks using the store. |
407 | void setSequenceNumberForGroup() override; |
408 | |
409 | // Retrieves the current sequence number for the whole group, which should be |
410 | // in sync. If the returned number is not consistent across the group, it |
411 | // may indicate that there is some sort of collective desynchronization. |
412 | uint64_t getSequenceNumberForGroup() override; |
413 | |
414 | // Tests if the UCC fallback path is available |
415 | bool isUCCAvailable() const; |
416 | |
417 | protected: |
418 | // Helper that broadcasts nccl unique ID to all ranks through the store |
419 | void broadcastUniqueNCCLID( |
420 | ncclUniqueId* ncclID, |
421 | bool isSingleP2POp, |
422 | const std::string& devicesKey, |
423 | int p2pRank); |
424 | |
425 | // Helper that either looks up the cached NCCL communicators or creates |
426 | // a new set of NCCL communicators as a cache entry |
427 | std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm( |
428 | const std::string& devicesKey, |
429 | const std::vector<at::Device>& devices, |
430 | OpType opType, |
431 | int p2pRank = 0, |
432 | bool isSendRecvSelf = false); |
433 | |
434 | // Wrapper method which can be overridden for tests. |
435 | virtual std::exception_ptr checkForNCCLErrors( |
436 | const std::vector<std::shared_ptr<NCCLComm>>& ncclComms); |
437 | |
438 | virtual c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork( |
439 | std::vector<at::Device> devices, |
440 | int rank, |
441 | OpType opType, |
442 | const char* profilingTitle=nullptr, |
443 | const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt); |
444 | |
445 | virtual c10::intrusive_ptr<ProcessGroupNCCL::CoalescedWorkNCCL> |
446 | initCoalescedWork( |
447 | const std::vector<c10::intrusive_ptr<Work>>& works, |
448 | int rank, |
449 | OpType opType); |
450 | |
451 | private: |
452 | // Helper that encapsulates work shared across all collective communication |
453 | // primitives. The callbacks have the following signatures: |
454 | // |
455 | // ncclResult_t fn(at::Tensor& input, at::Tensor& output, |
456 | // ncclComm_t, at::cuda::CUDAStream&); |
457 | // void {pre,post}(std::vector<at::cuda::CUDAStream&>); |
458 | template <typename Fn> |
459 | c10::intrusive_ptr<Work> collective( |
460 | std::vector<at::Tensor>& input, |
461 | std::vector<at::Tensor>& output, |
462 | Fn fn, |
463 | OpType opType, |
464 | const char* profilingTitle = nullptr); |
465 | template <typename Fn, typename PreProcess, typename PostProcess> |
466 | c10::intrusive_ptr<Work> collective( |
467 | std::vector<at::Tensor>& input, |
468 | std::vector<at::Tensor>& output, |
469 | Fn fn, |
470 | PreProcess pre, |
471 | PostProcess post, |
472 | OpType opType, |
473 | const char* profilingTitle = nullptr); |
474 | |
475 | // Helper that encapsulates work shared across point-to-point communication |
476 | // primitives. It is the same structure as the helper used for collective |
477 | // communicaiton primitives. |
478 | template <typename Fn> |
479 | c10::intrusive_ptr<Work> pointToPoint( |
480 | std::vector<at::Tensor>& tensor, |
481 | Fn fn, |
482 | int peer, |
483 | OpType opType, |
484 | const char* profilingTitle = nullptr); |
485 | template <typename Fn, typename PreProcess, typename PostProcess> |
486 | c10::intrusive_ptr<Work> pointToPoint( |
487 | std::vector<at::Tensor>& tensor, |
488 | Fn fn, |
489 | int peer, |
490 | OpType opType, |
491 | PreProcess pre, |
492 | PostProcess post, |
493 | const char* profilingTitle); |
494 | |
495 | c10::intrusive_ptr<Work> allreduce_impl( |
496 | std::vector<at::Tensor>& tensors, |
497 | const AllreduceOptions& opts = AllreduceOptions()); |
498 | |
499 | // Checks for NCCL errors on each of the communicators and returns an |
500 | // appropriate exception_ptr (nullptr if no errors). |
501 | static std::exception_ptr checkForNCCLErrorsInternal( |
502 | const std::vector<std::shared_ptr<NCCLComm>>& ncclComms); |
503 | |
504 | // Function that runs as part of a separate thread and checks for errors on |
505 | // NCCL communicators. We need a separate thread to check for NCCL errors |
506 | // since we can't rely on the user calling certain methods like wait(), |
507 | // isCompleted() etc. to detect and remediate errors. In addition to this, we |
508 | // need a mechanism to safely abort and remove NCCL communicators from our |
509 | // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL |
510 | // class. Attempting to modify the communicator cache from the WorkNCCL class |
511 | // might run into issues with object lifetime since the ProcessGroupNCCL |
512 | // object might get destroyed before the WorkNCCL object. |
513 | void ncclCommWatchdog(); |
514 | |
515 | void ncclCommWatchdogInternal(); |
516 | |
517 | // This function iterates through the list of WorkNCCL objects in the |
518 | // workList_ corresponding to incomplete collectives and then aborts NCCL |
519 | // communicators associated with timed out collectives. |
520 | void abortTimedOutCollectives( |
521 | std::unordered_set<std::string>& abortedCommIds); |
522 | |
523 | // Performs a health check by initializing dummy NCCL communicators and then |
524 | // destroying them. This will help indicate and signal any NCCL-related issues |
525 | // prior to the first collective. The actual initialization and subsequent |
526 | // destruction is ran on a separate thread and the main thread is signalled |
527 | // about timeouts/errors to report to the application. |
528 | void runHealthCheck(); |
529 | |
530 | // Destroys initialized NCCL communicators in devNCCLComMap_ given by input |
531 | // key. Throws if there are no communicators to destroy. Also removes |
532 | // communicators from the cache and clears used device indices. |
533 | void destroyNCCLComms(const std::string& devNCCLCommMapKey); |
534 | |
535 | void workCleanupLoop(); |
536 | |
537 | protected: |
538 | static const int64_t kWatchdogThreadSleepMillis; |
539 | static const int64_t kWorkCleanupThreadSleepMillis; |
540 | |
541 | // The store is used to broadcast the NCCL unique ID of rank 0. |
542 | c10::intrusive_ptr<Store> store_; |
543 | |
544 | bool storeError_{false}; |
545 | |
546 | const c10::intrusive_ptr<Options> options_; |
547 | |
548 | // The number of NCCL communicators that have been created during |
549 | // the lifetime of this process group. This sequence number is |
550 | // used to scope keys used in the store. |
551 | uint64_t ncclCommCounter_{0}; |
552 | |
553 | // The store keys to trace the last NCCL collective kernel CUDA events - start |
554 | // event and end event respectively. These are used to do desync root cause |
555 | // analysis. |
556 | const std::string traceKeyStart_; |
557 | const std::string traceKeyEnd_; |
558 | |
559 | // The NCCL communicator that the process group has cached. |
560 | // |
561 | // For collective operations: |
562 | // The key is a list of GPU devices that an operation is operating on |
563 | // The GPU devices are stored in a device sequence and the cache NCCL |
564 | // communicator is associated with this GPU device sequence |
565 | // |
566 | // e.g. If the process group op only uses device 0, then the value of |
567 | // the used device string stored (value of the hashmap) would be "0". |
568 | // |
569 | // If the process group op uses device 0 - 7 and the each tensor of the |
570 | // input tensor list is on device, 0, 1, 2, 3, 4, 5, 6, 7 separately, |
571 | // then the value of the used device string (key) stored would be |
572 | // "0,1,2,3,4,5,6,7" |
573 | // |
574 | // If the process group op uses device 0 - 7 and the each tensor of the |
575 | // input tensor list is on device, 0, 4, 5, 6, 7, 1, 2, 3 separately, |
576 | // then the value of the used device string stored would be |
577 | // "0,4,5,6,7,1,2,3" |
578 | // |
579 | // Note that the order of the device for the tensor list matters. |
580 | // |
581 | // For point-to-point operations: |
582 | // The key is a string of my current rank and the peer process rank. |
583 | // e.g. If process 1 and process 2 are involved in a point-to-point |
584 | // communication, the key will be "1:2" on both processes. Note: this is for |
585 | // the scenario where there is only 1 GPU per process. When it comes to |
586 | // multiple GPUs per process, this part may need to redesigned. |
587 | std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>> |
588 | devNCCLCommMap_; |
589 | |
590 | // Map from ncclUniqueId to appropriate communicator. |
591 | std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>> |
592 | ncclIdToCommMap_; |
593 | |
594 | // Mutex to guard maps like devNCCLCommMap_ and ncclIdToCommMap_. |
595 | std::mutex mutex_; |
596 | |
597 | // Watchdog thread which looks for errors on the cached NCCL communicators. |
598 | std::thread ncclCommWatchdogThread_; |
599 | |
600 | // Whether or not we should terminate the watchdog and workCleanup threads. |
601 | std::atomic<bool> terminateProcessGroup_; |
602 | |
603 | // Condition variable to control how long the watchdog thread waits. |
604 | std::condition_variable watchdogCV_; |
605 | |
606 | // Mutex for watchdog. |
607 | std::mutex watchdogCVMutex_; |
608 | |
609 | // Thread that removes NCCL Work upon timeout |
610 | std::thread workCleanupThread_; |
611 | |
612 | // Mutex to Guard workMetaList_ |
613 | std::mutex workMetaListMutex_; |
614 | |
615 | // Condition Variable for timeout thread sleep |
616 | std::condition_variable workMetaListCV_; |
617 | |
618 | // Vector to Store WorkNCCL pointers |
619 | std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_; |
620 | |
621 | // Add Work Pointer to workVector |
622 | void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>); |
623 | |
624 | // The CUDA steams used by NCCL kernels |
625 | std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>> |
626 | ncclStreams_; |
627 | |
628 | // The CUDA events used to sync NCCL streams |
629 | std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_; |
630 | |
631 | // Device Indexes used for all collectives in this group |
632 | std::set<int> usedDeviceIdxs_; |
633 | |
634 | // Flag to denote if a coalescing groupStart/groupEnd block is active |
635 | bool coalescing_active_ = false; |
636 | |
637 | // Stores device indexes for all collectives run inside a coalescing block |
638 | std::vector<std::vector<at::Device>> coalescedDevices_; |
639 | |
640 | // map from the key: "group name + pg counter (ID)" to the |
641 | // unique NCCL ID count. This needs to be group and pg specific |
642 | // |
643 | // For each process group, we need a uniform unique NCCL ID counter to ensure |
644 | // that NCCL operation in this process group can be completed successfully. |
645 | // Since each process group ID belongs to a group name, the key to this map |
646 | // is a combination of group name and ProcessGroupNCCL ID. |
647 | static std::unordered_map<std::string, ssize_t> pgUniqueNCCLIDCnt_; |
648 | |
649 | // map from group name to the pg counter (ID) within that group |
650 | // |
651 | // For each group with the "group name" (which is the key), we need to |
652 | // keep track of a unique process group ID when creating a new |
653 | // ProcessGroupNCCL for this "group name". Therefore, the value of this |
654 | // map keeps the unique ProcessGroupNCCL's ID for a specific group with |
655 | // the "group name". The reason we need a per-group process group ID counter |
656 | // is that different group can have different ranks and we need ensure that |
657 | // each group has its own uniform process group ID for all its ranks. |
658 | static std::unordered_map<std::string, ssize_t> processGroupCounterMap_; |
659 | |
660 | // Whether or not wait() and synchronize() are blocking operations that wait |
661 | // for the operation to complete. |
662 | bool blockingWait_ = false; |
663 | |
664 | // Whether or not the workCleanupThread is used to perform async error |
665 | // handling. |
666 | ErrorHandlingMode asyncErrorHandling_ = NoHandling; |
667 | |
668 | // Whether or not to enable timeout root cause analysis. |
669 | bool desyncDebug_; |
670 | |
671 | // Set of communicators that this process group has aborted and their |
672 | // ncclUniqueId has been written to the store. We don't need a lock |
673 | // for this map since only the watchdog thread accesses this set. The |
674 | // set contains the string representation of ncclUniqueId. |
675 | std::unordered_set<std::string> abortedComms_; |
676 | |
677 | // The number of active ncclGroupStart() calls. This counter will be increased |
678 | // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd() |
679 | // is called. |
680 | static thread_local uint64_t ncclActiveGroupCounter_; |
681 | |
682 | // Counting for the sequential number of NCCL collective call. |
683 | uint64_t seq_{0}; |
684 | |
685 | #ifdef USE_NCCL_WITH_UCC |
686 | // ProcessGroupUCC shared library handle and ProcessGroup pointer |
687 | static std::shared_ptr<at::DynamicLibrary> uccLib_; |
688 | c10::intrusive_ptr<Backend> uccPG_; |
689 | #endif |
690 | }; |
691 | |
692 | } // namespace c10d |
693 | |
694 | #endif // USE_C10D_NCCL |
695 | |