1 | #pragma once |
2 | |
3 | #ifdef USE_C10D_NCCL |
4 | |
#include <stdio.h>
#include <stdlib.h>

#include <memory>
#include <mutex>
#include <string>
#include <utility>

#include <nccl.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
14 | |
// ncclGetLastError() and ncclRemoteError are only available starting with
// NCCL 2.13, so both feature macros are defined for 2.13+ and for any major
// version >= 3.
#if defined(NCCL_MAJOR) &&   \
    ((NCCL_MAJOR >= 3) ||    \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 13)))
#define ENABLE_NCCL_GET_LAST_ERROR
#define NCCL_REMOTE_ERROR
#endif
25 | |
// Error checking relies on ncclCommAbort() and ncclCommGetAsyncError(), which
// earlier NCCL versions do not support; enable it for 2.4+ and for any major
// version >= 3.
#if defined(NCCL_MAJOR) &&   \
    ((NCCL_MAJOR >= 3) ||    \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 4)))
#define ENABLE_NCCL_ERROR_CHECKING
#endif
34 | |
// P2P relies on ncclSend() and ncclRecv(), which earlier NCCL versions do not
// support; enable it for 2.7+ and for any major version >= 3.
#if defined(NCCL_MAJOR) &&   \
    ((NCCL_MAJOR >= 3) ||    \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 7)))
#define ENABLE_NCCL_P2P_SUPPORT
#endif
43 | |
// Premul-sum support is enabled only for NCCL versions 2.11+ (and for any
// major version >= 3).
#if defined(NCCL_MAJOR) &&   \
    ((NCCL_MAJOR >= 3) ||    \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 11)))
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
#endif
49 | |
// Macro to throw on a non-successful NCCL return value.
//
// `cmd` is evaluated exactly once. If it does not yield ncclSuccess, an error
// message containing the call site, the NCCL error string and the optional
// `failureReason` is built and reported through
// TORCH_CHECK_WITH(DistBackendError, ...). `failureReason` is only evaluated
// on the failure path.
//
// Hygiene notes: the macro-local names carry a unique prefix so they cannot
// shadow identifiers that appear inside `cmd` or `failureReason`, `cmd` is
// parenthesized, and the c10d helpers are fully qualified so the macro can be
// expanded outside of namespace c10d as well.
#define C10D_NCCL_CHECK(cmd, failureReason)                                   \
  do {                                                                        \
    const ncclResult_t c10dNcclCheckResult_ = (cmd);                          \
    if (c10dNcclCheckResult_ != ncclSuccess) {                                \
      const std::string c10dNcclCheckMsg_ = "NCCL error in: " +               \
          std::string(__FILE__) + ":" + std::to_string(__LINE__) + ", " +     \
          ::c10d::ncclGetErrorWithVersion(c10dNcclCheckResult_) + "\n" +      \
          ::c10d::getNcclErrorDetailStr(c10dNcclCheckResult_, failureReason); \
      TORCH_CHECK_WITH(DistBackendError, false, c10dNcclCheckMsg_);           \
    }                                                                         \
  } while (0)
61 | |
// Macro to print and abort on a non-successful NCCL return value.
//
// `cmd` is evaluated exactly once. If it does not yield ncclSuccess, the call
// site and the NCCL error string are written to stderr and the process is
// terminated with abort(). Nothing is thrown, which is why this variant is
// the one used from the noexcept NCCLComm destructor.
//
// Hygiene notes: the macro-local names carry a unique prefix so they cannot
// shadow identifiers that appear inside `cmd`, `cmd` is parenthesized, and the
// c10d helper is fully qualified so the macro can be expanded outside of
// namespace c10d as well.
#define C10D_NCCL_ASSERT(cmd)                                      \
  do {                                                             \
    const ncclResult_t c10dNcclAssertResult_ = (cmd);              \
    if (c10dNcclAssertResult_ != ncclSuccess) {                    \
      const std::string c10dNcclAssertMsg_ =                       \
          ::c10d::ncclGetErrorWithVersion(c10dNcclAssertResult_);  \
      fprintf(                                                     \
          stderr,                                                  \
          "NCCL error in: %s:%d, %s\n",                            \
          __FILE__,                                                \
          __LINE__,                                                \
          c10dNcclAssertMsg_.c_str());                             \
      abort();                                                     \
    }                                                              \
  } while (0)
77 | |
78 | namespace c10d { |
79 | |
80 | std::string getNcclVersion(); |
81 | std::string ncclGetErrorWithVersion(ncclResult_t error); |
82 | |
83 | // Provides additional detail into NCCL error codes based on when these are |
84 | // thrown in the NCCL codebase. |
85 | std::string getNcclErrorDetailStr( |
86 | ncclResult_t error, |
87 | c10::optional<std::string> processGroupFailureReason = c10::nullopt); |
88 | |
89 | // RAII wrapper for NCCL communicator |
90 | class NCCLComm { |
91 | public: |
92 | explicit NCCLComm(ncclComm_t ncclComm) |
93 | : ncclComm_(ncclComm), |
94 | aborted_(false), |
95 | ncclAsyncErr_(ncclSuccess), |
96 | commFailureReason_(c10::nullopt) {} |
97 | |
98 | NCCLComm() : NCCLComm(nullptr) {} |
99 | |
100 | ~NCCLComm() noexcept { |
101 | // Add lock in this destructor, as aborted_ needs to be read after memory |
102 | // barrier here. |
103 | std::unique_lock<std::mutex> lock(mutex_); |
104 | if (ncclComm_ && !aborted_) { |
105 | #ifdef ENABLE_NCCL_ERROR_CHECKING |
106 | // Use ncclCommAbort instead of ncclCommDestroy here since |
107 | // ncclCommDestroy could block forever waiting for work to complete on |
108 | // the communicator. |
109 | C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_)); |
110 | #else |
111 | C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_)); |
112 | #endif |
113 | } |
114 | } |
115 | |
116 | static std::shared_ptr<NCCLComm> create( |
117 | int numRanks, |
118 | int rank, |
119 | ncclUniqueId commId) { |
120 | auto comm = std::make_shared<NCCLComm>(); |
121 | C10D_NCCL_CHECK( |
122 | ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt); |
123 | comm->ncclId_ = commId; |
124 | comm->rank_ = rank; |
125 | return comm; |
126 | } |
127 | |
128 | ncclUniqueId getNcclId() { |
129 | return ncclId_; |
130 | } |
131 | |
132 | // Must not be copyable |
133 | NCCLComm(const NCCLComm&) = delete; |
134 | NCCLComm& operator=(const NCCLComm&) = delete; |
135 | |
136 | // Do not support move assignment as there is no valid use case |
137 | NCCLComm& operator=(NCCLComm&& other) = delete; |
138 | |
139 | // Move constructable |
140 | NCCLComm(NCCLComm&& other) { |
141 | // Using other's lock, as it reads other's states |
142 | // Can not use this.mutex_, as this object is being constructed. |
143 | std::unique_lock<std::mutex> lock(other.mutex_); |
144 | std::swap(ncclComm_, other.ncclComm_); |
145 | std::swap(aborted_, other.aborted_); |
146 | std::swap(ncclAsyncErr_, other.ncclAsyncErr_); |
147 | } |
148 | |
149 | ncclComm_t getNcclComm(); |
150 | |
151 | c10::optional<std::string> getNcclCommFailureReason() const { |
152 | std::unique_lock<std::mutex> lock(mutex_); |
153 | return commFailureReason_; |
154 | } |
155 | |
156 | void ncclCommAbort( |
157 | c10::optional<std::string> commFailureReason = c10::nullopt) { |
158 | std::unique_lock<std::mutex> lock(mutex_); |
159 | #ifdef ENABLE_NCCL_ERROR_CHECKING |
160 | if (aborted_) { |
161 | // Should not abort twice. |
162 | return; |
163 | } |
164 | |
165 | // Set true failure reason if provided by ProcessGroupNCCL (e.g. work |
166 | // timeout) |
167 | commFailureReason_ = commFailureReason; |
168 | |
169 | C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_); |
170 | aborted_ = true; |
171 | ncclComm_ = nullptr; |
172 | |
173 | // Set an appropriate error so that we avoid using the communicator. |
174 | if (ncclAsyncErr_ == ncclSuccess) { |
175 | ncclAsyncErr_ = ncclSystemError; |
176 | } |
177 | #else |
178 | // This is a NOOP, if error checks are disabled. |
179 | return; |
180 | #endif |
181 | } |
182 | |
183 | bool isAborted() const { |
184 | std::unique_lock<std::mutex> lock(mutex_); |
185 | return aborted_; |
186 | } |
187 | |
188 | ncclResult_t checkForNcclError() { |
189 | std::unique_lock<std::mutex> lock(mutex_); |
190 | #ifdef ENABLE_NCCL_ERROR_CHECKING |
191 | if (ncclAsyncErr_ != ncclSuccess) { |
192 | return ncclAsyncErr_; |
193 | } |
194 | C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_); |
195 | return ncclAsyncErr_; |
196 | #else |
197 | // Always return success, if error checks are disabled. |
198 | return ncclSuccess; |
199 | #endif |
200 | } |
201 | |
202 | protected: |
203 | ncclComm_t ncclComm_; |
204 | // Unique nccl_id for this communicator. |
205 | ncclUniqueId ncclId_; |
206 | bool aborted_; |
207 | ncclResult_t ncclAsyncErr_; |
208 | mutable std::mutex mutex_; |
209 | // Rank that this communicator corresponds to. |
210 | int rank_; |
211 | // Optional reason for communicator failure, provided by ProcessGroupNCCL for |
212 | // better error messaging. |
213 | c10::optional<std::string> commFailureReason_; |
214 | }; |
215 | |
216 | // Helper that automatically cleans up premul sums. |
217 | struct ncclRedOpRAII { |
218 | ncclRedOpRAII() = default; |
219 | ncclRedOpRAII(ncclRedOp_t op) : op_(op) {} |
220 | ncclRedOpRAII(ncclRedOp_t op, ncclComm_t comm) : |
221 | op_(op), comm_(comm), premul_sum_(true) {} |
222 | ncclRedOpRAII(const ncclRedOpRAII&) = delete; |
223 | ncclRedOpRAII& operator=(const ncclRedOpRAII&) = delete; |
224 | ncclRedOpRAII(ncclRedOpRAII&& tmp) : ncclRedOpRAII() { |
225 | std::swap(tmp.op_, this->op_); |
226 | std::swap(tmp.comm_, this->comm_); |
227 | std::swap(tmp.premul_sum_, this->premul_sum_); |
228 | } |
229 | #if defined(ENABLE_NCCL_PREMUL_SUM_SUPPORT) |
230 | ~ncclRedOpRAII() { |
231 | if (premul_sum_) { |
232 | ncclRedOpDestroy(op_, comm_); |
233 | } |
234 | } |
235 | #endif |
236 | operator ncclRedOp_t() const { return op_; } |
237 | ncclRedOp_t op_; |
238 | ncclComm_t comm_; |
239 | bool premul_sum_ = false; |
240 | }; |
241 | |
242 | |
243 | } // namespace c10d |
244 | |
245 | #endif // USE_C10D_NCCL |
246 | |