1#pragma once
2
3#ifdef USE_C10D_NCCL
4
5#include <stdio.h>
6#include <stdlib.h>
7
8#include <memory>
9#include <mutex>
10
11#include <nccl.h>
12#include <c10/util/Exception.h>
13#include <c10/util/Optional.h>
14
// ncclGetLastError() and the ncclRemoteError result code only exist in
// NCCL 2.13 and newer, so both feature macros share a single version gate
// (any 2.x release with minor >= 13, or any major release >= 3).
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) || \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 13)))
#define ENABLE_NCCL_GET_LAST_ERROR
#define NCCL_REMOTE_ERROR
#endif
25
// ncclCommAbort() and ncclCommGetAsyncError() are not available before
// NCCL 2.4, so error checking is switched on only for 2.4+ (or any major
// release >= 3).
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) || \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 4)))
#define ENABLE_NCCL_ERROR_CHECKING
#endif
34
// ncclSend() and ncclRecv() are not available before NCCL 2.7, so
// point-to-point support is switched on only for 2.7+ (or any major
// release >= 3).
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) || \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 7)))
#define ENABLE_NCCL_P2P_SUPPORT
#endif
43
// Pre-multiplied sum reductions are enabled for NCCL 2.11 and newer (or any
// major release >= 3).
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) || \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 11)))
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#endif
49
// Evaluates `cmd` exactly once and, when the returned ncclResult_t is not
// ncclSuccess, raises a DistBackendError through TORCH_CHECK_WITH.
//
// The message has the form
//   "NCCL error in: <file>:<line>, <ncclGetErrorWithVersion(result)>\n<detail>"
// where <detail> is getNcclErrorDetailStr(result, failureReason).
//
// The helper functions are referenced unqualified, so the macro has to be
// expanded where ncclGetErrorWithVersion / getNcclErrorDetailStr are visible.
// The locals `result` and `err` are visible to the macro arguments; do not
// pass expressions that rely on outer variables with those names.
#define C10D_NCCL_CHECK(cmd, failureReason) \
  do { \
    ncclResult_t result = cmd; \
    if (result != ncclSuccess) { \
      std::string err = "NCCL error in: "; \
      err += __FILE__; \
      err += ":"; \
      err += std::to_string(__LINE__); \
      err += ", "; \
      err += ncclGetErrorWithVersion(result); \
      err += "\n"; \
      err += getNcclErrorDetailStr(result, failureReason); \
      TORCH_CHECK_WITH(DistBackendError, false, err); \
    } \
  } while (0)
61
// Evaluates `cmd` exactly once and, when the returned ncclResult_t is not
// ncclSuccess, writes "NCCL error in: <file>:<line>, <error>" to stderr and
// terminates the process with abort(). Intended for contexts that must not
// throw (it is used from the NCCLComm destructor).
//
// The locals `result` and `err` are visible to the macro argument; do not
// pass expressions that rely on outer variables with those names.
#define C10D_NCCL_ASSERT(cmd) \
  do { \
    ncclResult_t result = cmd; \
    if (result != ncclSuccess) { \
      std::string err = ncclGetErrorWithVersion(result); \
      fprintf(stderr, "NCCL error in: %s:%d, %s\n", __FILE__, __LINE__, err.c_str()); \
      abort(); \
    } \
  } while (0)
77
78namespace c10d {
79
80std::string getNcclVersion();
81std::string ncclGetErrorWithVersion(ncclResult_t error);
82
83// Provides additional detail into NCCL error codes based on when these are
84// thrown in the NCCL codebase.
85std::string getNcclErrorDetailStr(
86 ncclResult_t error,
87 c10::optional<std::string> processGroupFailureReason = c10::nullopt);
88
89// RAII wrapper for NCCL communicator
90class NCCLComm {
91 public:
92 explicit NCCLComm(ncclComm_t ncclComm)
93 : ncclComm_(ncclComm),
94 aborted_(false),
95 ncclAsyncErr_(ncclSuccess),
96 commFailureReason_(c10::nullopt) {}
97
98 NCCLComm() : NCCLComm(nullptr) {}
99
100 ~NCCLComm() noexcept {
101 // Add lock in this destructor, as aborted_ needs to be read after memory
102 // barrier here.
103 std::unique_lock<std::mutex> lock(mutex_);
104 if (ncclComm_ && !aborted_) {
105#ifdef ENABLE_NCCL_ERROR_CHECKING
106 // Use ncclCommAbort instead of ncclCommDestroy here since
107 // ncclCommDestroy could block forever waiting for work to complete on
108 // the communicator.
109 C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_));
110#else
111 C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_));
112#endif
113 }
114 }
115
116 static std::shared_ptr<NCCLComm> create(
117 int numRanks,
118 int rank,
119 ncclUniqueId commId) {
120 auto comm = std::make_shared<NCCLComm>();
121 C10D_NCCL_CHECK(
122 ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
123 comm->ncclId_ = commId;
124 comm->rank_ = rank;
125 return comm;
126 }
127
128 ncclUniqueId getNcclId() {
129 return ncclId_;
130 }
131
132 // Must not be copyable
133 NCCLComm(const NCCLComm&) = delete;
134 NCCLComm& operator=(const NCCLComm&) = delete;
135
136 // Do not support move assignment as there is no valid use case
137 NCCLComm& operator=(NCCLComm&& other) = delete;
138
139 // Move constructable
140 NCCLComm(NCCLComm&& other) {
141 // Using other's lock, as it reads other's states
142 // Can not use this.mutex_, as this object is being constructed.
143 std::unique_lock<std::mutex> lock(other.mutex_);
144 std::swap(ncclComm_, other.ncclComm_);
145 std::swap(aborted_, other.aborted_);
146 std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
147 }
148
149 ncclComm_t getNcclComm();
150
151 c10::optional<std::string> getNcclCommFailureReason() const {
152 std::unique_lock<std::mutex> lock(mutex_);
153 return commFailureReason_;
154 }
155
156 void ncclCommAbort(
157 c10::optional<std::string> commFailureReason = c10::nullopt) {
158 std::unique_lock<std::mutex> lock(mutex_);
159#ifdef ENABLE_NCCL_ERROR_CHECKING
160 if (aborted_) {
161 // Should not abort twice.
162 return;
163 }
164
165 // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
166 // timeout)
167 commFailureReason_ = commFailureReason;
168
169 C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
170 aborted_ = true;
171 ncclComm_ = nullptr;
172
173 // Set an appropriate error so that we avoid using the communicator.
174 if (ncclAsyncErr_ == ncclSuccess) {
175 ncclAsyncErr_ = ncclSystemError;
176 }
177#else
178 // This is a NOOP, if error checks are disabled.
179 return;
180#endif
181 }
182
183 bool isAborted() const {
184 std::unique_lock<std::mutex> lock(mutex_);
185 return aborted_;
186 }
187
188 ncclResult_t checkForNcclError() {
189 std::unique_lock<std::mutex> lock(mutex_);
190#ifdef ENABLE_NCCL_ERROR_CHECKING
191 if (ncclAsyncErr_ != ncclSuccess) {
192 return ncclAsyncErr_;
193 }
194 C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
195 return ncclAsyncErr_;
196#else
197 // Always return success, if error checks are disabled.
198 return ncclSuccess;
199#endif
200 }
201
202 protected:
203 ncclComm_t ncclComm_;
204 // Unique nccl_id for this communicator.
205 ncclUniqueId ncclId_;
206 bool aborted_;
207 ncclResult_t ncclAsyncErr_;
208 mutable std::mutex mutex_;
209 // Rank that this communicator corresponds to.
210 int rank_;
211 // Optional reason for communicator failure, provided by ProcessGroupNCCL for
212 // better error messaging.
213 c10::optional<std::string> commFailureReason_;
214};
215
216// Helper that automatically cleans up premul sums.
217struct ncclRedOpRAII {
218 ncclRedOpRAII() = default;
219 ncclRedOpRAII(ncclRedOp_t op) : op_(op) {}
220 ncclRedOpRAII(ncclRedOp_t op, ncclComm_t comm) :
221 op_(op), comm_(comm), premul_sum_(true) {}
222 ncclRedOpRAII(const ncclRedOpRAII&) = delete;
223 ncclRedOpRAII& operator=(const ncclRedOpRAII&) = delete;
224 ncclRedOpRAII(ncclRedOpRAII&& tmp) : ncclRedOpRAII() {
225 std::swap(tmp.op_, this->op_);
226 std::swap(tmp.comm_, this->comm_);
227 std::swap(tmp.premul_sum_, this->premul_sum_);
228 }
229#if defined(ENABLE_NCCL_PREMUL_SUM_SUPPORT)
230 ~ncclRedOpRAII() {
231 if (premul_sum_) {
232 ncclRedOpDestroy(op_, comm_);
233 }
234 }
235#endif
236 operator ncclRedOp_t() const { return op_; }
237 ncclRedOp_t op_;
238 ncclComm_t comm_;
239 bool premul_sum_ = false;
240};
241
242
243} // namespace c10d
244
245#endif // USE_C10D_NCCL
246