1 | #include <c10/util/Exception.h> |
2 | #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp> |
3 | |
4 | #ifdef USE_C10D_GLOO |
5 | |
6 | #include <torch/csrc/distributed/c10d/GlooDeviceFactory.hpp> |
7 | #include <torch/csrc/distributed/c10d/PrefixStore.hpp> |
8 | #include <chrono> |
9 | #include <exception> |
10 | #include <ratio> |
11 | #include <tuple> |
12 | |
13 | #ifdef _WIN32 |
14 | #include <gloo/common/win.h> |
15 | #include <winsock2.h> |
16 | #include <ws2tcpip.h> |
17 | #else |
18 | #include <netdb.h> |
19 | #include <sys/socket.h> |
20 | #include <unistd.h> |
21 | #endif |
22 | #include <sys/types.h> |
23 | |
24 | #include <type_traits> |
25 | |
26 | #include <gloo/allgather.h> |
27 | #include <gloo/allgatherv.h> |
28 | #include <gloo/allreduce.h> |
29 | #include <gloo/alltoall.h> |
30 | #include <gloo/alltoallv.h> |
31 | #include <gloo/barrier.h> |
32 | #include <gloo/broadcast.h> |
33 | #include <gloo/gather.h> |
34 | #include <gloo/reduce.h> |
35 | #include <gloo/scatter.h> |
36 | |
37 | #include <ATen/SparseTensorUtils.h> |
38 | #include <ATen/ThreadLocalState.h> |
39 | |
40 | #include <c10/util/StringUtil.h> |
41 | #include <c10/util/intrusive_ptr.h> |
42 | #include <c10/util/irange.h> |
43 | #include <gloo/config.h> |
44 | #include <gloo/rendezvous/context.h> |
45 | #include <gloo/rendezvous/prefix_store.h> |
46 | |
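// GENERATE_ALL_TYPES dispatches a templated function over the scalar types
// supported by this backend (float, double, half, and the 8/32/64-bit
// integral types). It is used below to bind typed data pointers and
// reduction functions to Gloo collective options at runtime.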
47 | #ifdef _WIN32 |
48 | #define GENERATE_ALL_TYPES(type, func, ...) \ |
49 | switch (type) { \ |
50 | case ::at::ScalarType::Float: \ |
51 | func<float>(__VA_ARGS__); \ |
52 | break; \ |
53 | case ::at::ScalarType::Double: \ |
54 | func<double>(__VA_ARGS__); \ |
55 | break; \ |
56 | case ::at::ScalarType::Half: \ |
57 | func<gloo::float16>(__VA_ARGS__); \ |
58 | break; \ |
59 | case ::at::ScalarType::Char: \ |
60 | func<int8_t>(__VA_ARGS__); \ |
61 | break; \ |
62 | case ::at::ScalarType::Byte: \ |
63 | func<uint8_t>(__VA_ARGS__); \ |
64 | break; \ |
65 | case ::at::ScalarType::Int: \ |
66 | func<int32_t>(__VA_ARGS__); \ |
67 | break; \ |
68 | case ::at::ScalarType::Long: \ |
69 | func<int64_t>(__VA_ARGS__); \ |
70 | break; \ |
71 | default: \ |
72 | TORCH_CHECK(false, "Invalid scalar type"); \ |
73 | } |
74 | |
75 | #define HOST_NAME_MAX 256 |
76 | #else |
77 | #define GENERATE_ALL_TYPES(type, func, args...) \ |
78 | switch (type) { \ |
79 | case ::at::ScalarType::Float: \ |
80 | func<float>(args); \ |
81 | break; \ |
82 | case ::at::ScalarType::Double: \ |
83 | func<double>(args); \ |
84 | break; \ |
85 | case ::at::ScalarType::Half: \ |
86 | func<gloo::float16>(args); \ |
87 | break; \ |
88 | case ::at::ScalarType::Char: \ |
89 | func<int8_t>(args); \ |
90 | break; \ |
91 | case ::at::ScalarType::Byte: \ |
92 | func<uint8_t>(args); \ |
93 | break; \ |
94 | case ::at::ScalarType::Int: \ |
95 | func<int32_t>(args); \ |
96 | break; \ |
97 | case ::at::ScalarType::Long: \ |
98 | func<int64_t>(args); \ |
99 | break; \ |
100 | default: \ |
101 | TORCH_CHECK(false, "Invalid scalar type"); \ |
102 | } |
103 | #endif |
104 | |
105 | namespace c10d { |
106 | |
107 | namespace { |
108 | |
109 | constexpr int kBytes = 8; |
110 | |
111 | using steady_clock_time_point = |
112 | std::chrono::time_point<std::chrono::steady_clock>; |
113 | |
114 | std::chrono::milliseconds getRemainingTime( |
115 | steady_clock_time_point startTime, |
116 | const std::chrono::milliseconds& timeout, |
117 | bool waitAllRanks) { |
118 | if (waitAllRanks) { |
119 | // See Note in monitoredBarrier |
120 | return timeout; |
121 | } |
122 | auto elapsedTime = std::chrono::steady_clock::now() - startTime; |
123 | auto remainingMillis = timeout - |
124 | std::chrono::duration_cast<std::chrono::milliseconds>(elapsedTime); |
125 | |
  // If no time remains, return -1 to signal timeout to the caller.
127 | if (remainingMillis.count() <= 0) { |
128 | return std::chrono::milliseconds(-1); |
129 | } |
130 | |
131 | return remainingMillis; |
132 | } |
133 | |
// Emits a LOG(ERROR) and throws via TORCH_CHECK with the given messages.
135 | void logAndThrow( |
136 | const std::string& logMessage, |
137 | const std::string& errorMessage) { |
138 | LOG(ERROR) << logMessage; |
139 | TORCH_CHECK(false, errorMessage); |
140 | } |
141 | |
// For monitoredBarrier, checks the remaining time left to finish processing
// ranks and throws an error if the timeout has elapsed.
144 | void checkRemainingTime( |
145 | const std::chrono::milliseconds& monitoredBarrierTimeout, |
146 | const std::chrono::milliseconds& remainingTime, |
147 | const std::vector<int>& processedRanks, |
148 | int currentRank) { |
149 | const std::string kNoRemainingTimeError = c10::str( |
150 | "Rank " , |
151 | currentRank, |
152 | " timed out in monitoredBarrier after " , |
153 | monitoredBarrierTimeout.count(), |
154 | " ms." ); |
155 | if (remainingTime.count() < 0) { |
156 | std::string rankInfo; |
157 | if (!processedRanks.empty()) { |
158 | rankInfo = c10::str( |
159 | "Successfully processed ranks: " , c10::Join(", " , processedRanks)); |
160 | } else { |
161 | rankInfo = "No ranks successfully processed in monitoredBarrier." ; |
162 | } |
163 | auto error = c10::str(kNoRemainingTimeError, "\n" , rankInfo); |
164 | logAndThrow(error, error); |
165 | } |
166 | } |
167 | |
168 | typedef void (*ReduceFunc)(void*, const void*, const void*, size_t); |
169 | |
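// Maps a c10d ReduceOp onto the corresponding Gloo reduction function.
// This non-integral overload rejects the bitwise ops, which are only
// defined for integral element types (see the integral overload below).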
170 | template < |
171 | typename T, |
172 | typename std::enable_if<!std::is_integral<T>::value, int>::type = 0> |
173 | ReduceFunc toFunction(const ReduceOp& r) { |
174 | switch (r) { |
175 | case ReduceOp::SUM: |
176 | return ReduceFunc(&::gloo::sum<T>); |
177 | case ReduceOp::PRODUCT: |
178 | return ReduceFunc(&::gloo::product<T>); |
179 | case ReduceOp::MIN: |
180 | return ReduceFunc(&::gloo::min<T>); |
181 | case ReduceOp::MAX: |
182 | return ReduceFunc(&::gloo::max<T>); |
183 | case ReduceOp::BAND: |
184 | TORCH_CHECK(false, "Cannot use ReduceOp.BAND with non-integral dtype" ); |
185 | break; |
186 | case ReduceOp::BOR: |
187 | TORCH_CHECK(false, "Cannot use ReduceOp.BOR with non-integral dtype" ); |
188 | break; |
189 | case ReduceOp::BXOR: |
190 | TORCH_CHECK(false, "Cannot use ReduceOp.BXOR with non-integral dtype" ); |
191 | break; |
192 | case ReduceOp::AVG: |
193 | TORCH_CHECK(false, "Cannot use ReduceOp.AVG with Gloo" ); |
194 | break; |
195 | case ReduceOp::PREMUL_SUM: |
196 | TORCH_CHECK(false, "Cannot use ReduceOp.PREMUL_SUM with Gloo" ); |
197 | break; |
198 | case ReduceOp::UNUSED: |
199 | break; |
200 | } |
201 | |
202 | TORCH_CHECK(false, "Unhandled ReduceOp" ); |
203 | } |
204 | |
205 | // Bitwise AND with SFINAE guard for integral types. |
206 | template < |
207 | typename T, |
208 | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
209 | void band(void* c, const void* a, const void* b, size_t n) { |
210 | auto tc = static_cast<T*>(c); |
211 | auto ta = static_cast<const T*>(a); |
212 | auto tb = static_cast<const T*>(b); |
213 | for (const auto i : c10::irange(n)) { |
214 | tc[i] = ta[i] & tb[i]; |
215 | } |
216 | } |
217 | |
218 | // Bitwise OR with SFINAE guard for integral types. |
219 | template < |
220 | typename T, |
221 | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
222 | void bor(void* c, const void* a, const void* b, size_t n) { |
223 | auto tc = static_cast<T*>(c); |
224 | auto ta = static_cast<const T*>(a); |
225 | auto tb = static_cast<const T*>(b); |
226 | for (const auto i : c10::irange(n)) { |
227 | tc[i] = ta[i] | tb[i]; |
228 | } |
229 | } |
230 | |
231 | // Bitwise XOR with SFINAE guard for integral types. |
232 | template < |
233 | typename T, |
234 | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
235 | void bxor(void* c, const void* a, const void* b, size_t n) { |
236 | auto tc = static_cast<T*>(c); |
237 | auto ta = static_cast<const T*>(a); |
238 | auto tb = static_cast<const T*>(b); |
239 | for (const auto i : c10::irange(n)) { |
240 | tc[i] = ta[i] ^ tb[i]; |
241 | } |
242 | } |
243 | |
244 | template < |
245 | typename T, |
246 | typename std::enable_if<std::is_integral<T>::value, int>::type = 0> |
247 | ReduceFunc toFunction(const ReduceOp& r) { |
248 | switch (r) { |
249 | case ReduceOp::SUM: |
250 | return ReduceFunc(&::gloo::sum<T>); |
251 | case ReduceOp::PRODUCT: |
252 | return ReduceFunc(&::gloo::product<T>); |
253 | case ReduceOp::MIN: |
254 | return ReduceFunc(&::gloo::min<T>); |
255 | case ReduceOp::MAX: |
256 | return ReduceFunc(&::gloo::max<T>); |
257 | case ReduceOp::BAND: |
258 | return ReduceFunc(&band<T>); |
259 | case ReduceOp::BOR: |
260 | return ReduceFunc(&bor<T>); |
261 | case ReduceOp::BXOR: |
262 | return ReduceFunc(&bxor<T>); |
263 | case ReduceOp::AVG: |
264 | TORCH_CHECK(false, "Cannot use ReduceOp.AVG with Gloo" ); |
265 | break; |
266 | case ReduceOp::PREMUL_SUM: |
267 | TORCH_CHECK(false, "Cannot use ReduceOp.PREMUL_SUM with Gloo" ); |
268 | break; |
269 | case ReduceOp::UNUSED: |
270 | break; |
271 | } |
272 | |
273 | TORCH_CHECK(false, "Unhandled ReduceOp" ); |
274 | } |
275 | |
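// Thin helpers that bind tensor data pointers (and, for the *v collectives,
// per-rank element counts) to Gloo collective options. They are instantiated
// per scalar type via GENERATE_ALL_TYPES.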
276 | template <typename T, typename O> |
277 | void setInputs(O& opts, std::vector<at::Tensor>& tensors) { |
278 | opts.setInputs(getDataPointers<T>(tensors), tensors[0].numel()); |
279 | } |
280 | |
281 | template <typename T, typename O> |
282 | void setInput(O& opts, at::Tensor& tensor) { |
283 | opts.setInput(getDataPointer<T>(tensor), tensor.numel()); |
284 | } |
285 | |
286 | template <typename T, typename O> |
287 | void setInput(O& opts, at::Tensor& tensor, std::vector<size_t>& counts) { |
288 | opts.setInput(getDataPointer<T>(tensor), counts); |
289 | } |
290 | |
291 | template <typename T, typename O> |
292 | void setInput(O& opts, at::Tensor& tensor, std::vector<int64_t>& counts) { |
293 | opts.setInput(getDataPointer<T>(tensor), counts); |
294 | } |
295 | |
296 | template <typename T, typename O> |
297 | void setOutputs(O& opts, std::vector<at::Tensor>& tensors) { |
298 | opts.setOutputs(getDataPointers<T>(tensors), tensors[0].numel()); |
299 | } |
300 | |
301 | template <typename T, typename O> |
302 | void setOutput(O& opts, at::Tensor& tensor) { |
303 | opts.setOutput(getDataPointer<T>(tensor), tensor.numel()); |
304 | } |
305 | |
306 | template <typename T, typename O> |
307 | void setOutput(O& opts, at::Tensor& tensor, std::vector<size_t>& counts) { |
308 | opts.setOutput(getDataPointer<T>(tensor), counts); |
309 | } |
310 | |
311 | template <typename T, typename O> |
312 | void setOutput(O& opts, at::Tensor& tensor, std::vector<int64_t>& counts) { |
313 | opts.setOutput(getDataPointer<T>(tensor), counts); |
314 | } |
315 | |
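// Returns a CPU tensor backed by pinned (page-locked) memory with the same
// sizes and strides as the given tensor. Pinned staging buffers let the
// host/device copies issued by the CUDA work classes below run with
// non_blocking=true.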
316 | at::Tensor pinnedLike(at::Tensor& tensor) { |
317 | auto* allocator = at::detail::getCUDAHooks().getPinnedMemoryAllocator(); |
318 | auto storage = c10::Storage( |
319 | c10::Storage::use_byte_size_t(), |
320 | at::detail::computeStorageNbytes( |
321 | tensor.sizes(), tensor.strides(), tensor.dtype().itemsize()), |
322 | allocator, |
323 | /*resizable=*/false); |
324 | return at::empty({0}, tensor.options().device(at::kCPU)) |
325 | .set_(storage, 0, tensor.sizes(), tensor.strides()); |
326 | } |
327 | |
328 | // This function initializes a vector of CUDA streams, one for every |
329 | // tensor in the input tensor vector, and ensures that these streams are |
330 | // synchronized with the current default streams. This is needed so |
331 | // that new work on the new streams is serialized w.r.t. all operations |
332 | // on the tensors. |
333 | void initializeStreamsEvents( |
334 | const std::vector<at::Tensor>& tensors, |
335 | std::vector<c10::Stream>& streams, |
336 | std::vector<c10::Event>& events) { |
337 | streams.reserve(tensors.size()); |
338 | events.reserve(tensors.size()); |
339 | for (const auto i : c10::irange(tensors.size())) { |
340 | c10::Device device = tensors[i].device(); |
341 | c10::impl::VirtualGuardImpl impl(device.type()); |
342 | // Record event on current stream |
343 | events.emplace_back(device.type()); |
344 | events[i].record(impl.getStream(device)); |
345 | // Get a non-default stream to execute asynchronous CUDA operations |
346 | // on for this device. This ensures that the default stream used |
347 | // by the caller is not occupied by c10d related operations. |
348 | streams.push_back( |
349 | impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true)); |
350 | // Ensure the new stream is synchronized with the current stream. |
351 | events[i].block(streams[i]); |
352 | |
353 | // `tensors` are created on a different stream. Hence, they must record |
354 | // new streams in this Work to prevent being freed before the Work finishes. |
355 | if (tensors[i].is_sparse()) { |
356 | if (tensors[i].is_coalesced()) { |
357 | impl.recordDataPtrOnStream( |
358 | tensors[i].indices().storage().data_ptr(), streams[i]); |
359 | impl.recordDataPtrOnStream( |
360 | tensors[i].values().storage().data_ptr(), streams[i]); |
361 | } else { |
362 | // We will need to coalesce first, which means new tensors will |
363 | // be allocated on the streams we just allocated, and there |
364 | // is no need to record them separately. |
365 | } |
366 | } else { |
367 | impl.recordDataPtrOnStream(tensors[i].storage().data_ptr(), streams[i]); |
368 | } |
369 | } |
370 | } |
371 | |
372 | // This function initializes a vector of CUDA streams, one per device, |
373 | // and ensures that these streams are synchronized with the current default |
374 | // streams. It is assumed that the tensors in the nested tensor vectors are |
375 | // on the same device. |
376 | void initializeStreamsEvents( |
377 | std::vector<std::vector<at::Tensor>>& tensors, |
378 | std::vector<c10::Stream>& streams, |
379 | std::vector<c10::Event>& events) { |
380 | // Ensure that the tensors in the nested tensor vectors are on the same |
381 | // device. |
382 | for (const auto& tensorgroup : tensors) { |
383 | const auto device_id = tensorgroup[0].device().index(); |
384 | for (const auto& tensor : tensorgroup) { |
385 | if (tensor.device().index() != device_id) { |
386 | TORCH_CHECK( |
387 | false, |
388 | "tensors in the nested tensor vectors need to " |
389 | "be on the same device" ); |
390 | } |
391 | } |
392 | } |
393 | |
394 | streams.reserve(tensors.size()); |
395 | events.reserve(tensors.size()); |
396 | for (const auto i : c10::irange(tensors.size())) { |
397 | c10::Device device = tensors[i][0].device(); |
398 | c10::impl::VirtualGuardImpl impl(device.type()); |
399 | // Record event on current stream |
400 | events.emplace_back(device.type()); |
401 | events[i].record(impl.getStream(device)); |
402 | // Get a non-default stream to execute asynchronous CUDA operations |
403 | // on for this output. This ensures that the default stream used |
404 | // by the caller is not occupied by c10d related operations. |
405 | streams.push_back( |
406 | impl.getStreamFromGlobalPool(device, /*isHighPriority=*/true)); |
407 | // Ensure the new stream is synchronized with the current stream. |
408 | events[i].block(streams[i]); |
409 | |
410 | for (at::Tensor& tensor : tensors[i]) { |
411 | // `tensors` are created on a different stream. Hence, they must record |
412 | // new streams in this Work to prevent being freed before the Work |
413 | // finishes. |
414 | impl.recordDataPtrOnStream(tensor.storage().data_ptr(), streams[i]); |
415 | } |
416 | } |
417 | } |
418 | |
419 | const auto kLoopbackAddress = "127.0.0.1" ; |
420 | |
421 | } // namespace |
422 | |
423 | // static |
424 | void ProcessGroupGloo::AsyncWork::execute(c10::intrusive_ptr<AsyncWork> work) { |
425 | if (work->recordFunctionBeforeCallback_) { |
426 | work->recordFunctionBeforeCallback_(); |
427 | } |
428 | try { |
429 | work->run(); |
430 | } catch (...) { |
431 | work->finishWorkGlooError(std::current_exception()); |
432 | return; |
433 | } |
434 | |
435 | // FIXME: We need to call it here since Future completion requires all |
436 | // the work to be synchronized to CUDA. |
437 | work->synchronize(); |
438 | work->finishWorkGloo(); |
439 | } |
440 | |
441 | std::vector<at::Tensor> ProcessGroupGloo::AsyncWork::result() { |
442 | TORCH_CHECK( |
443 | isCompleted(), |
444 | "Work needs to be completed before calling result(). " |
445 | "Should call wait() before result()." ); |
446 | TORCH_CHECK( |
447 | outputTensors_.size() <= 1, |
448 | "work result does not support list of lists, use .getFuture() and value()" ); |
449 | return outputTensors_.empty() ? std::vector<at::Tensor>() |
450 | : outputTensors_.at(0); |
451 | } |
452 | |
453 | c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupGloo::AsyncWork:: |
454 | getFuture() { |
455 | return future_; |
456 | } |
457 | |
458 | namespace { |
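// The future's element type mirrors the result layout: a nested list of
// tensor lists when there is more than one output group, otherwise a flat
// list of tensors.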
459 | c10::intrusive_ptr<c10::ivalue::Future> createFutureAsOutput( |
460 | const std::vector<std::vector<at::Tensor>>& outputTensors) { |
461 | if (outputTensors.size() > 1) { |
462 | return c10::make_intrusive<c10::ivalue::Future>( |
463 | c10::ListType::create(c10::ListType::create(c10::TensorType::get()))); |
464 | } |
465 | return c10::make_intrusive<c10::ivalue::Future>( |
466 | c10::ListType::create(c10::TensorType::get())); |
467 | } |
468 | |
469 | void returnFutureWithOutput( |
470 | c10::intrusive_ptr<c10::ivalue::Future>& future, |
471 | const std::vector<std::vector<at::Tensor>>& outputTensors) { |
472 | if (outputTensors.empty()) { |
473 | future->markCompleted(c10::IValue(std::vector<at::Tensor>())); |
474 | return; |
475 | } |
476 | if (outputTensors.size() > 1) { |
477 | future->markCompleted(c10::IValue(outputTensors)); |
478 | return; |
479 | } |
480 | future->markCompleted(c10::IValue(outputTensors[0])); |
481 | } |
482 | } // namespace |
483 | |
484 | inline void ProcessGroupGloo::AsyncWork::recordAsyncWorkProfilingInfo( |
485 | const char* profilingTitle, |
486 | const c10::optional<std::vector<at::Tensor>>& inputTensors) { |
487 | auto recordingFunction = |
488 | std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE); |
489 | if (recordingFunction->isActive()) { |
490 | std::function<void()> before_handler = |
491 | [inputTensors, profilingTitle, recordingFunction]() { |
492 | // The work will be started and completed by different threads. |
493 | recordingFunction->_setAsync(); |
494 | std::vector<c10::IValue> inputs; |
495 | if (inputTensors) { |
496 | inputs.reserve(inputTensors->size()); |
497 | for (const auto& tensor : *inputTensors) { |
498 | inputs.emplace_back(tensor); |
499 | } |
500 | } |
501 | recordingFunction->before( |
502 | profilingTitle, |
503 | c10::ArrayRef<const c10::IValue>(inputs.data(), inputs.size())); |
504 | }; |
505 | recordFunctionBeforeCallback_ = at::wrapPropagateTLSState(before_handler); |
506 | std::function<void()> end_handler = [recordingFunction]() { |
507 | recordingFunction->end(); |
508 | }; |
509 | recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler); |
510 | } |
511 | } |
512 | |
513 | ProcessGroupGloo::AsyncWork::AsyncWork( |
514 | std::vector<std::vector<at::Tensor>> outputTensors, |
515 | const char* profilingTitle, |
516 | const c10::optional<std::vector<at::Tensor>>& inputTensors) |
517 | // Profiler: Pass nullptr as profilingTitle to parent constructor to |
518 | // replace default profiler implementation with async version that reports |
519 | // correct timestamps for work that is asynchronously executed. |
520 | : Work(-1, OpType::UNKNOWN, nullptr, inputTensors), |
521 | outputTensors_(std::move(outputTensors)), |
522 | future_(createFutureAsOutput(outputTensors_)) { |
523 | if (profilingTitle != nullptr) { |
524 | recordAsyncWorkProfilingInfo(profilingTitle, inputTensors); |
525 | } |
526 | } |
527 | |
528 | void ProcessGroupGloo::AsyncWork::finishWorkGlooError(std::exception_ptr eptr) { |
529 | future_->setError(eptr); |
530 | finish(eptr); |
531 | } |
532 | |
533 | void ProcessGroupGloo::AsyncWork::finishWorkGloo() { |
534 | returnFutureWithOutput(future_, outputTensors_); |
535 | finish(); |
536 | } |
537 | |
538 | ProcessGroupGloo::SendWork::SendWork( |
539 | at::Tensor& tensor, |
540 | std::unique_ptr<::gloo::transport::UnboundBuffer> buffer) |
541 | : Work( |
542 | -1, |
543 | OpType::SEND, |
544 | "gloo:send" , |
545 | c10::optional<std::vector<at::Tensor>>({tensor})), |
546 | tensor_(tensor), |
547 | buffer_(std::move(buffer)) {} |
548 | |
549 | bool ProcessGroupGloo::SendWork::wait(std::chrono::milliseconds timeout) { |
550 | bool sendCompleted = false; |
551 | std::exception_ptr exception{nullptr}; |
552 | try { |
553 | if (timeout == kNoTimeout) { |
554 | sendCompleted = buffer_->waitSend(); |
555 | } else { |
556 | sendCompleted = buffer_->waitSend(timeout); |
557 | } |
558 | } catch (...) { |
559 | exception = std::current_exception(); |
560 | } |
561 | |
562 | // Completes the Work object and throws the exception. |
563 | finishAndThrow(exception); |
564 | return sendCompleted; |
565 | } |
566 | |
567 | void ProcessGroupGloo::SendWork::abort() { |
568 | buffer_->abortWaitSend(); |
569 | } |
570 | |
571 | ProcessGroupGloo::RecvWork::RecvWork( |
572 | at::Tensor& tensor, |
573 | std::unique_ptr<::gloo::transport::UnboundBuffer> buffer, |
574 | const char* profilingTitle) |
575 | : Work( |
576 | -1, |
577 | OpType::UNKNOWN, |
578 | profilingTitle, |
579 | c10::optional<std::vector<at::Tensor>>({tensor})), |
580 | tensor_(tensor), |
581 | buffer_(std::move(buffer)), |
582 | srcRank_(-1) {} |
583 | |
584 | int ProcessGroupGloo::RecvWork::sourceRank() const { |
585 | std::lock_guard<std::mutex> lock(mutex_); |
586 | return srcRank_; |
587 | } |
588 | |
589 | bool ProcessGroupGloo::RecvWork::wait(std::chrono::milliseconds timeout) { |
590 | bool recvCompleted = false; |
591 | std::exception_ptr exception{nullptr}; |
592 | try { |
593 | if (timeout == kNoTimeout) { |
594 | recvCompleted = buffer_->waitRecv(&srcRank_); |
595 | } else { |
596 | recvCompleted = buffer_->waitRecv(&srcRank_, timeout); |
597 | } |
598 | } catch (...) { |
599 | exception = std::current_exception(); |
600 | } |
601 | |
602 | // Completes the Work object and throws the exception. |
603 | finishAndThrow(exception); |
604 | return recvCompleted; |
605 | } |
606 | |
607 | void ProcessGroupGloo::RecvWork::abort() { |
608 | buffer_->abortWaitRecv(); |
609 | } |
610 | |
611 | ProcessGroupGloo::Options::Options(std::chrono::milliseconds timeout) |
612 | : Backend::Options(GLOO_BACKEND_NAME, timeout), threads(2) {} |
613 | |
614 | namespace { |
615 | |
616 | void socketInitialize() { |
617 | #ifdef _WIN32 |
618 | ::gloo::init_winsock(); |
619 | #endif |
620 | } |
621 | |
// Gloo assumes that this machine's hostname can always be resolved
// to an address. If it can't, Gloo throws a runtime error saying
// that the hostname can't be resolved. Instead of catching that error, we
// proactively check whether the hostname resolves to a usable address, so
// we can gracefully fall back to an alternative if it doesn't.
627 | bool doesHostnameResolveToUsableAddress(const std::string& hostname) { |
628 | socketInitialize(); |
629 | struct addrinfo hints {}; |
630 | memset(&hints, 0, sizeof(hints)); |
631 | hints.ai_family = AF_UNSPEC; |
632 | hints.ai_socktype = SOCK_STREAM; |
633 | struct addrinfo* result = nullptr; |
  auto rv = getaddrinfo(hostname.c_str(), nullptr, &hints, &result);
  // getaddrinfo returns 0 on success and a nonzero error code on failure
  // (the EAI_* codes are not guaranteed to be negative on all platforms).
  if (rv != 0) {
    return false;
  }
638 | struct addrinfo* rp = nullptr; |
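  // Try to bind a socket to each resolved address; a successful bind means
  // the address belongs to a local interface and is usable for rendezvous.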
639 | for (rp = result; rp != nullptr; rp = rp->ai_next) { |
640 | auto fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); |
641 | if (fd == -1) { |
642 | continue; |
643 | } |
644 | rv = bind(fd, rp->ai_addr, rp->ai_addrlen); |
645 | #ifdef _WIN32 |
646 | closesocket(fd); |
647 | #else |
648 | close(fd); |
649 | #endif |
650 | if (rv == -1) { |
651 | continue; |
652 | } |
653 | break; |
654 | } |
655 | freeaddrinfo(result); |
656 | return rp != nullptr; |
657 | } |
658 | |
659 | } // namespace |
660 | |
661 | std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: |
662 | createDeviceForInterface(const std::string& interface_name) { |
663 | return ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface_name); |
664 | } |
665 | |
666 | std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: |
667 | createDeviceForHostname(const std::string& hostname) { |
668 | TORCH_CHECK( |
669 | doesHostnameResolveToUsableAddress(hostname), |
670 | "Cannot resolve " , |
671 | hostname, |
672 | " to a (local) address" ); |
673 | return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname); |
674 | } |
675 | |
676 | #if defined(__linux__) || defined(_WIN32) |
677 | std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: |
678 | createDefaultDevice() { |
679 | // Use the hostname to resolve the network address to |
680 | // use. Note: if the hostname does not resolve to an address (e.g. |
681 | // because of misconfigured /etc/hosts file), this will not work. |
682 | socketInitialize(); |
683 | std::array<char, HOST_NAME_MAX> hostname{}; |
684 | auto rv = gethostname(hostname.data(), HOST_NAME_MAX); |
685 | if (rv != 0) { |
686 | throw std::system_error(errno, std::system_category()); |
687 | } |
688 | |
689 | // Use this machine's hostname if it resolves to an address. |
690 | if (doesHostnameResolveToUsableAddress(hostname.data())) { |
691 | return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.data()); |
692 | } |
693 | |
694 | // Otherwise, use the loopback address. |
695 | TORCH_WARN_ONCE( |
696 | "Unable to resolve hostname to a (local) address. " , |
697 | "Using the loopback address as fallback. " , |
698 | "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME." ); |
699 | return createDeviceForHostname(kLoopbackAddress); |
700 | } |
701 | #endif |
702 | |
703 | #ifdef __APPLE__ |
704 | std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: |
705 | createDefaultDevice() { |
706 | // Use the hostname to resolve the network address to |
707 | // use. Note: if the hostname does not resolve to an address (e.g. |
708 | // because of misconfigured /etc/hosts file), this will not work. |
709 | const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX); |
710 | auto hostname = std::unique_ptr<char[]>(new char[hostNameMax]); |
711 | auto rv = gethostname(hostname.get(), hostNameMax); |
712 | if (rv != 0) { |
713 | throw std::system_error(errno, std::system_category()); |
714 | } |
715 | |
716 | // Use this machine's hostname if it resolves to an address. |
717 | if (doesHostnameResolveToUsableAddress(hostname.get())) { |
718 | return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.get()); |
719 | } |
720 | |
721 | // Otherwise, use the loopback address. |
722 | TORCH_WARN_ONCE( |
723 | "Unable to resolve hostname to a (local) address. " , |
724 | "Using the loopback address as fallback. " , |
725 | "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME." ); |
726 | return createDeviceForHostname(kLoopbackAddress); |
727 | } |
728 | #endif |
729 | |
730 | ProcessGroupGloo::ProcessGroupGloo( |
731 | const c10::intrusive_ptr<Store>& store, |
732 | int rank, |
733 | int size, |
734 | c10::intrusive_ptr<Options> options) |
735 | : Backend(rank, size), |
736 | store_(new GlooStore(store)), |
737 | options_(options), |
738 | stop_(false), |
739 | collectiveCounter_(0) { |
740 | auto& devices = options->devices; |
741 | if (devices.empty()) { |
742 | TORCH_CHECK(false, "No device(s) specified" ); |
743 | } |
744 | |
745 | // Create and connect a context for every device. |
746 | // |
747 | // Note that the same device can be specified multiple times, either |
748 | // the same object, or the same logical device as different objects. |
749 | // Either mode is fine and only has performance implications. |
750 | // |
751 | // Using the same object multiple times means all contexts share a |
752 | // single I/O thread. If you use different objects for the same |
753 | // logical device they will have independent I/O threads. The latter |
754 | // option is needed if you have a fast NIC that cannot be saturated |
755 | // by a single I/O thread. |
756 | // |
757 | contexts_.reserve(options->devices.size()); |
758 | for (const auto i : c10::irange(options->devices.size())) { |
759 | auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_); |
760 | auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_); |
761 | context->setTimeout(options->timeout); |
762 | context->connectFullMesh(store, options->devices[i]); |
763 | contexts_.push_back(std::move(context)); |
764 | } |
765 | |
766 | // Every worker thread stores the AsyncWork object it's currently |
767 | // working on in the workInProgress_ vector. It must have size equal |
768 | // to the number of workers such that they can simply index into it |
769 | // using the worker index they are started with. |
770 | workInProgress_.resize(options->threads); |
771 | |
772 | threads_.resize(options->threads); |
773 | for (const auto i : c10::irange(threads_.size())) { |
774 | threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i); |
775 | } |
776 | |
777 | init(); |
778 | } |
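
// A minimal usage sketch (illustrative only; assumes a rendezvous Store such
// as c10d::TCPStore has already been created elsewhere):
//
//   auto opts = ProcessGroupGloo::Options::create();
//   opts->devices.push_back(ProcessGroupGloo::createDefaultDevice());
//   auto pg = c10::make_intrusive<ProcessGroupGloo>(store, rank, size, opts);
//   std::vector<at::Tensor> tensors = {at::ones({8})};
//   pg->allreduce(tensors)->wait();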
779 | |
780 | ProcessGroupGloo::~ProcessGroupGloo() { |
781 | std::unique_lock<std::mutex> lock(workMutex_); |
782 | workConsumeCV_.wait(lock, [&] { return workQueue_.empty(); }); |
783 | |
784 | // Queue is empty, signal stop |
785 | stop_ = true; |
786 | |
787 | // Release lock to allow threads to terminate |
788 | lock.unlock(); |
789 | |
790 | workProduceCV_.notify_all(); |
791 | |
792 | // Wait for worker threads to terminate |
793 | for (auto& thread : threads_) { |
794 | thread.join(); |
795 | } |
796 | } |
797 | |
798 | uint32_t ProcessGroupGloo::nextTag() { |
799 | return collectiveCounter_++; |
800 | } |
801 | |
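// Collectives pick their context by taking the tag modulo the number of
// configured contexts, so successive collectives round-robin over the
// available devices/connections.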
802 | std::shared_ptr<::gloo::Context> ProcessGroupGloo::getContext(uint32_t tag) { |
803 | return contexts_[tag % contexts_.size()]; |
804 | } |
805 | |
806 | void ProcessGroupGloo::runLoop(int workerIndex) { |
807 | std::unique_lock<std::mutex> lock(workMutex_); |
808 | |
809 | while (!stop_) { |
810 | if (workQueue_.empty()) { |
811 | workProduceCV_.wait(lock); |
812 | continue; |
813 | } |
814 | |
815 | auto work = std::move(workQueue_.front()); |
816 | workQueue_.pop_front(); |
817 | workInProgress_[workerIndex] = work; |
818 | lock.unlock(); |
819 | |
820 | // Notify after releasing the lock so that the waiter |
821 | // does not immediately block. |
822 | workConsumeCV_.notify_one(); |
823 | |
824 | AsyncWork::execute(std::move(work)); |
825 | lock.lock(); |
826 | workInProgress_[workerIndex].reset(); |
827 | } |
828 | } |
829 | |
830 | void ProcessGroupGloo::enqueue(c10::intrusive_ptr<AsyncWork> work) { |
831 | std::unique_lock<std::mutex> lock(workMutex_); |
832 | // Bump collective counter |
833 | if (sequenceNum_) { |
834 | sequenceNum_->increment(); |
835 | } |
836 | workQueue_.push_back(std::move(work)); |
837 | lock.unlock(); |
838 | |
839 | // Notify after releasing the lock so that the waiter |
840 | // does not immediately block. |
841 | workProduceCV_.notify_one(); |
842 | } |
843 | |
844 | namespace { |
845 | |
846 | class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { |
847 | public: |
848 | AsyncBroadcastWork( |
849 | const std::shared_ptr<gloo::Context>& context, |
850 | std::vector<at::Tensor>& inputs, |
851 | int rootRank, |
852 | int rootTensor, |
853 | uint32_t tag) |
854 | : ProcessGroupGloo::AsyncWork({inputs}, "gloo:broadcast" , inputs), |
855 | context(context), |
856 | inputs(inputs), |
857 | rootRank(rootRank), |
858 | rootTensor(rootTensor), |
859 | tag(tag) {} |
860 | |
861 | std::shared_ptr<gloo::Context> context; |
862 | std::vector<at::Tensor> inputs; |
863 | const int rootRank; |
864 | const int rootTensor; |
865 | const uint32_t tag; |
866 | |
867 | void broadcast(at::Tensor& tensor) { |
868 | const auto& scalarType = tensor.scalar_type(); |
869 | gloo::BroadcastOptions opts(context); |
870 | opts.setRoot(rootRank); |
871 | opts.setTag(tag); |
872 | GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor); |
873 | gloo::broadcast(opts); |
874 | } |
875 | |
876 | void run() override { |
877 | broadcast(inputs[rootTensor]); |
878 | |
879 | // Copy to non-root tensors |
880 | for (const auto i : c10::irange(inputs.size())) { |
881 | if (i == static_cast<size_t>(rootTensor)) { |
882 | continue; |
883 | } |
884 | inputs[i].copy_(inputs[rootTensor]); |
885 | } |
886 | } |
887 | }; |
888 | |
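// CUDA variant of broadcast: Gloo operates on host memory, so on the root
// rank the constructor stages the root tensor into a pinned host buffer on a
// side stream, run() broadcasts the host buffer, and the result is then
// copied back into every CUDA input asynchronously.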
889 | class AsyncBroadcastCUDAWork : public AsyncBroadcastWork { |
890 | public: |
891 | AsyncBroadcastCUDAWork( |
892 | const std::shared_ptr<gloo::Context>& context, |
893 | std::vector<at::Tensor>& inputs, |
894 | int rootRank, |
895 | int rootTensor, |
896 | uint32_t tag) |
897 | : AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag) { |
898 | initializeStreamsEvents(inputs, streams, events); |
899 | |
900 | // Create pinned host side tensors. |
901 | tmp = pinnedLike(inputs[rootTensor]); |
902 | c10::OptionalStreamGuard guard; |
903 | if (context->rank == rootRank) { |
904 | guard.reset_stream(streams[rootTensor]); |
905 | tmp.copy_(inputs[rootTensor], /* non_blocking */ true); |
906 | } |
907 | } |
908 | |
909 | void run() override { |
910 | // Synchronize with copy operation if applicable. |
911 | if (context->rank == rootRank) { |
912 | streams[rootTensor].synchronize(); |
913 | } |
914 | |
915 | // Run broadcast on host side tensors. |
916 | broadcast(tmp); |
917 | |
918 | // Kick off copy back to the CUDA tensors. |
919 | c10::OptionalStreamGuard guard; |
920 | for (const auto i : c10::irange(inputs.size())) { |
921 | guard.reset_stream(streams[i]); |
922 | inputs[i].copy_(tmp, /* non_blocking */ true); |
923 | events[i].record(streams[i]); |
924 | } |
925 | } |
926 | |
927 | void synchronize() override { |
928 | // Synchronize with the copy back to CUDA tensors. |
929 | for (const auto i : c10::irange(inputs.size())) { |
930 | c10::Device device = inputs[i].device(); |
931 | events[i].block( |
932 | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); |
933 | } |
934 | } |
935 | |
936 | at::Tensor tmp; |
937 | std::vector<c10::Stream> streams; |
938 | std::vector<c10::Event> events; |
939 | }; |
940 | |
941 | } // namespace |
942 | |
943 | c10::intrusive_ptr<Work> ProcessGroupGloo::broadcast( |
944 | std::vector<at::Tensor>& inputs, |
945 | const BroadcastOptions& opts) { |
946 | static auto invalidArgument = [](const std::string& msg) { |
947 | TORCH_CHECK(false, "ProcessGroupGloo::broadcast: " + msg); |
948 | }; |
949 | |
950 | assertRootRank(invalidArgument, opts.rootRank, size_); |
951 | assertRootTensor(invalidArgument, opts.rootTensor, inputs.size()); |
952 | assertDense(invalidArgument, inputs); |
953 | assertTypeAndSizesMatch(invalidArgument, inputs); |
954 | |
955 | const auto& device = inputs[0].device(); |
956 | switch (device.type()) { |
957 | case at::kCPU: |
958 | break; |
959 | case at::kCUDA: |
960 | // If the user gave us a CUDA tensor then CUDA must be loaded. |
961 | TORCH_INTERNAL_ASSERT(at::hasCUDA()); |
962 | break; |
963 | default: |
964 | invalidArgument(c10::str("unsupported device type " , device.type())); |
965 | } |
966 | |
967 | c10::intrusive_ptr<AsyncBroadcastWork> work; |
968 | auto tag = nextTag(); |
969 | auto context = getContext(tag); |
970 | if (device.type() == at::kCPU) { |
971 | work = c10::make_intrusive<AsyncBroadcastWork>( |
972 | std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); |
973 | } else if (device.type() == at::kCUDA) { |
974 | work = c10::make_intrusive<AsyncBroadcastCUDAWork>( |
975 | std::move(context), inputs, opts.rootRank, opts.rootTensor, tag); |
976 | } else { |
977 | TORCH_CHECK(false, "Invalid backend" ); |
978 | } |
979 | |
980 | enqueue(work); |
981 | return work; |
982 | } |
983 | |
984 | namespace { |
985 | |
986 | class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { |
987 | public: |
988 | AsyncAllreduceWork( |
989 | const std::shared_ptr<gloo::Context>& context, |
990 | std::vector<at::Tensor>& inputs, |
991 | ReduceOp reduceOp, |
992 | uint32_t tag) |
993 | : ProcessGroupGloo::AsyncWork({inputs}, "gloo:all_reduce" , inputs), |
994 | context(context), |
995 | inputs(inputs), |
996 | reduceOp(reduceOp), |
997 | tag(tag) {} |
998 | |
999 | std::shared_ptr<gloo::Context> context; |
1000 | std::vector<at::Tensor> inputs; |
1001 | const ReduceOp reduceOp; |
1002 | const uint32_t tag; |
1003 | |
1004 | void allreduce(std::vector<at::Tensor>& tensors) { |
1005 | const auto& scalarType = tensors[0].scalar_type(); |
1006 | gloo::AllreduceOptions opts(context); |
1007 | opts.setReduceFunction(getFunction(scalarType, reduceOp)); |
1008 | opts.setTag(tag); |
1009 | GENERATE_ALL_TYPES(scalarType, setOutputs, opts, tensors); |
1010 | gloo::allreduce(opts); |
1011 | } |
1012 | |
1013 | void run() override { |
1014 | allreduce(inputs); |
1015 | } |
1016 | |
1017 | template <typename T> |
1018 | void getFunction(gloo::AllreduceOptions::Func& fn, const ReduceOp op) { |
1019 | fn = toFunction<T>(op); |
1020 | } |
1021 | |
1022 | gloo::AllreduceOptions::Func getFunction( |
1023 | const at::ScalarType& dtype, |
1024 | const ReduceOp op) { |
1025 | gloo::AllreduceOptions::Func fn; |
1026 | GENERATE_ALL_TYPES(dtype, getFunction, fn, op); |
1027 | return fn; |
1028 | } |
1029 | }; |
1030 | |
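// Coalesced variant: the inputs are flattened into a single contiguous
// buffer, reduced with one Gloo allreduce call, and the results are copied
// back into the original tensors.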
1031 | class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork { |
1032 | public: |
1033 | AsyncAllreduceCoalescedWork( |
1034 | const std::shared_ptr<gloo::Context>& context, |
1035 | std::vector<at::Tensor>& inputs, |
1036 | ReduceOp reduceOp, |
1037 | uint32_t tag) |
1038 | : AsyncAllreduceWork(context, inputs, reduceOp, tag) {} |
1039 | |
1040 | void run() override { |
1041 | allreduceCoalesced(inputs); |
1042 | } |
1043 | |
1044 | private: |
1045 | void allreduceCoalesced(std::vector<at::Tensor>& tensors) { |
1046 | // reduce coalesced, flattened tensors. |
1047 | at::Tensor coalescedTensor = flattenDenseTensors(tensors); |
1048 | std::vector<at::Tensor> allreduceInput = {coalescedTensor}; |
1049 | allreduce(allreduceInput); |
1050 | |
1051 | // separate and reshape tensors. |
1052 | size_t offset = 0; |
1053 | for (at::Tensor& tensor : tensors) { |
1054 | const int64_t tensorNumel = tensor.numel(); |
1055 | const c10::IntArrayRef tensorShape = tensor.sizes(); |
1056 | tensor.copy_(coalescedTensor.slice(0, offset, offset + tensorNumel) |
1057 | .view(tensorShape)); |
1058 | offset += tensorNumel; |
1059 | } |
1060 | } |
1061 | }; |
1062 | |
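// Sparse (COO) allreduce: ranks first exchange tensor metadata, then
// allgather indices and values, and every rank sums the reconstructed
// sparse tensors locally. See allreduce() below.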
1063 | class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { |
1064 | public: |
1065 | AsyncSparseAllreduceWork( |
1066 | const std::shared_ptr<gloo::Context>& context, |
1067 | std::vector<at::Tensor>& inputs, |
1068 | uint32_t tag) |
1069 | : ProcessGroupGloo::AsyncWork({inputs}, "gloo:sparse_all_reduce" , inputs), |
1070 | context(context), |
1071 | inputs(inputs), |
1072 | tag(tag) {} |
1073 | |
1074 | std::shared_ptr<gloo::Context> context; |
1075 | std::vector<at::Tensor> inputs; |
1076 | const uint32_t tag; |
1077 | |
  // We share the dimensionality of the sparse tensors before collecting
1079 | // their contents. We assume here that the maximum number of sparse |
1080 | // and dense dimensions is 4. This is stored in a contiguous piece of |
1081 | // memory so that we can easily run allgather on it. |
1082 | // |
1083 | // The layout of this memory is as follows: |
1084 | // |
1085 | // - [0:4]: sparse dims |
1086 | // - [4:8]: dense dims |
1087 | // - [8]: nnz |
1088 | // |
1089 | class SparseTensorMetadata { |
1090 | public: |
1091 | static constexpr auto dim = 9; |
1092 | |
1093 | // Construct from an existing metadata tensor to facilitate structured |
1094 | // access to metadata from peers, after gathering it. |
1095 | explicit SparseTensorMetadata(at::Tensor metadata) |
1096 | : metadata_(metadata), data_(metadata_.data_ptr<int64_t>()) { |
1097 | AT_ASSERT(metadata.scalar_type() == at::kLong); |
1098 | AT_ASSERT(metadata.dim() == 1); |
1099 | AT_ASSERT(metadata.size(0) == dim); |
1100 | } |
1101 | |
1102 | // Populate the metadata. |
1103 | void populate_from_sparse_tensor(const at::Tensor& tensor) { |
1104 | const auto sparse_dim = tensor.sparse_dim(); |
1105 | AT_ASSERT(sparse_dim <= 4); |
1106 | for (const auto i : c10::irange(4)) { |
1107 | if (i < sparse_dim) { |
1108 | data_[i] = tensor.size(i); |
1109 | } |
1110 | } |
1111 | const auto dense_dim = tensor.dense_dim(); |
1112 | AT_ASSERT(dense_dim <= 4); |
1113 | for (const auto i : c10::irange(4)) { |
1114 | if (i < dense_dim) { |
1115 | data_[i + 4] = tensor.size(sparse_dim + i); |
1116 | } |
1117 | } |
1118 | data_[8] = tensor._nnz(); |
1119 | } |
1120 | |
1121 | std::vector<int64_t> sizes() const { |
1122 | std::vector<int64_t> sizes; |
1123 | // Sparse sizes |
1124 | for (const auto i : c10::irange(4)) { |
1125 | if (data_[i] <= 0) { |
1126 | break; |
1127 | } |
1128 | sizes.push_back(data_[i]); |
1129 | } |
1130 | // Dense sizes |
1131 | for (const auto i : c10::irange(4, 8)) { |
1132 | if (data_[i] <= 0) { |
1133 | break; |
1134 | } |
1135 | sizes.push_back(data_[i]); |
1136 | } |
1137 | return sizes; |
1138 | } |
1139 | |
1140 | int64_t nnz() const { |
1141 | return data_[8]; |
1142 | } |
1143 | |
1144 | protected: |
1145 | at::Tensor metadata_; |
1146 | int64_t* data_; |
1147 | }; |
1148 | |
  // Sparse allreduce is implemented with an allgather on indices and values.
  // Every process then sums the resulting sparse tensors locally.
  // The nnz may differ across processes, so we first allgather the per-rank
  // metadata (including nnz) and then run allgatherv with per-rank counts
  // for the indices and values.
1153 | at::Tensor allreduce(std::vector<at::Tensor>& tensors) { |
1154 | // TODO: This is a massive hack! There is some confusion about |
1155 | // Variable/Tensor inside the body of this function. Turning off |
1156 | // grad smooths over the confusion for now. This fixes |
1157 | // test/test_c10d_gloo.py ProcessGroupGlooTest.test_sparse_allreduce_basics |
1158 | // |
1159 | // The correct fix is to stop allocating tensors that are not variables, |
1160 | // but to conveniently do this c10d must depend on torch not ATen |
1161 | at::AutoDispatchBelowAutograd guard; |
1162 | auto input = tensors[0]; |
1163 | |
1164 | // Perform local reduction if we have multiple inputs. |
1165 | for (const auto i : c10::irange(1, tensors.size())) { |
1166 | input += tensors[i]; |
1167 | } |
1168 | |
1169 | // Need to coalesce before we can access indices and values. |
1170 | input = input.coalesce(); |
1171 | |
1172 | // Gather metadata information from all ranks. |
1173 | auto metadata = allgather_metadata(input); |
1174 | |
1175 | // Sanity check dimensionality across ranks. |
1176 | { |
1177 | const auto expected = metadata[context->rank].sizes(); |
1178 | for (const auto i : c10::irange(context->size)) { |
1179 | if (i == context->rank) { |
1180 | continue; |
1181 | } |
1182 | const auto actual = metadata[i].sizes(); |
1183 | TORCH_CHECK(actual == expected, "Sparse dimensions do not match" ); |
1184 | } |
1185 | } |
1186 | |
1187 | // Gather all indices and all values. |
1188 | auto indices = allgather_indices(input, metadata); |
1189 | auto values = allgather_values(input, metadata); |
1190 | |
1191 | // Perform global reduction. |
1192 | AT_ASSERT(static_cast<int>(indices.size()) == context->size); |
1193 | AT_ASSERT(static_cast<int>(values.size()) == context->size); |
1194 | auto output = at::sparse_coo_tensor( |
1195 | indices[0], values[0], input.sizes(), input.options()); |
1196 | for (const auto i : c10::irange(1, context->size)) { |
1197 | output += at::sparse_coo_tensor( |
1198 | indices[i], values[i], input.sizes(), input.options()); |
1199 | } |
1200 | |
1201 | // Coalesce for good measure. |
1202 | return output.coalesce(); |
1203 | } |
1204 | |
1205 | void run() override { |
1206 | auto output = allreduce(inputs); |
1207 | |
    // This copy is needed when we run a multi-gpu version of allreduce
    // (multiple inputs per rank).
1210 | for (const auto i : c10::irange(inputs.size())) { |
1211 | inputs[i].copy_(output); |
1212 | } |
1213 | } |
1214 | |
1215 | private: |
1216 | std::vector<SparseTensorMetadata> allgather_metadata( |
1217 | const at::Tensor& tensor) { |
1218 | auto buffer = |
1219 | at::zeros({context->size, SparseTensorMetadata::dim}, at::kLong); |
1220 | |
1221 | // Prepare metadata vector (1 entry per rank) |
1222 | std::vector<SparseTensorMetadata> metadata; |
1223 | metadata.reserve(context->size); |
1224 | for (const auto i : c10::irange(context->size)) { |
1225 | metadata.emplace_back(buffer.select(0, i)); |
1226 | } |
1227 | |
1228 | // Populate data for this rank |
1229 | metadata[context->rank].populate_from_sparse_tensor(tensor); |
1230 | |
1231 | // Allgather metadata |
1232 | gloo::AllgatherOptions opts(context); |
1233 | opts.setOutput(buffer.data_ptr<int64_t>(), buffer.numel()); |
1234 | opts.setTag(tag); |
1235 | gloo::allgather(opts); |
1236 | |
1237 | return metadata; |
1238 | } |
1239 | |
1240 | std::vector<at::Tensor> allgather_indices( |
1241 | const at::Tensor& tensor, |
1242 | const std::vector<SparseTensorMetadata>& metadata) { |
1243 | const auto sparseDim = tensor.sparse_dim(); |
1244 | |
1245 | std::vector<size_t> counts(context->size); |
1246 | int64_t totalSize = 0; |
1247 | for (const auto i : c10::irange(metadata.size())) { |
1248 | counts[i] = metadata[i].nnz() * sparseDim; |
1249 | totalSize += counts[i]; |
1250 | } |
1251 | |
1252 | auto output = at::empty({totalSize}, at::kLong); |
1253 | |
    // Tensors copied from CUDA may not be contiguous; get a contiguous
    // tensor before using its data_ptr.
1256 | auto input = tensor.indices().contiguous(); |
1257 | |
1258 | // Allgatherv indices. |
1259 | gloo::AllgathervOptions opts(context); |
1260 | opts.setInput(input.data_ptr<int64_t>(), input.numel()); |
1261 | opts.setOutput(output.data_ptr<int64_t>(), counts); |
1262 | opts.setTag(tag); |
1263 | gloo::allgatherv(opts); |
1264 | |
1265 | // Compile indices tensor per rank. |
1266 | std::vector<at::Tensor> indices; |
1267 | indices.reserve(metadata.size()); |
1268 | size_t offset = 0; |
1269 | for (const auto& i : metadata) { |
1270 | const auto nnz = i.nnz(); |
1271 | const auto numel = sparseDim * nnz; |
1272 | indices.push_back( |
1273 | output.narrow(0, offset, numel).reshape({sparseDim, nnz})); |
1274 | offset += numel; |
1275 | } |
1276 | |
1277 | return indices; |
1278 | } |
1279 | |
1280 | std::vector<at::Tensor> allgather_values( |
1281 | const at::Tensor& tensor, |
1282 | const std::vector<SparseTensorMetadata>& metadata) { |
1283 | // There are nnz #dense_dim()-dimensional tensors per rank. |
1284 | const auto valueShape = tensor.sizes().slice(tensor.sparse_dim()); |
1285 | size_t denseNumel = 1; |
1286 | for (auto dim : valueShape) { |
1287 | denseNumel *= dim; |
1288 | } |
1289 | |
1290 | std::vector<size_t> counts(context->size); |
1291 | int64_t totalSize = 0; |
1292 | for (const auto i : c10::irange(metadata.size())) { |
1293 | counts[i] = metadata[i].nnz() * denseNumel; |
1294 | totalSize += counts[i]; |
1295 | } |
1296 | |
1297 | auto output = at::empty({totalSize}, tensor.scalar_type()); |
1298 | |
    // Allgatherv values.
1300 | gloo::AllgathervOptions opts(context); |
    // Tensors copied from CUDA may not be contiguous; get a contiguous
    // tensor before using its data_ptr.
1303 | at::Tensor valueTensor = tensor.values().contiguous(); |
1304 | GENERATE_ALL_TYPES(valueTensor.scalar_type(), setInput, opts, valueTensor); |
1305 | GENERATE_ALL_TYPES( |
1306 | valueTensor.scalar_type(), setOutput, opts, output, counts); |
1307 | opts.setTag(tag); |
1308 | gloo::allgatherv(opts); |
1309 | |
1310 | // Compile values tensor per rank. |
1311 | std::vector<at::Tensor> values; |
1312 | values.reserve(metadata.size()); |
1313 | size_t offset = 0; |
1314 | for (const auto& i : metadata) { |
1315 | const auto nnz = i.nnz(); |
1316 | const auto numel = denseNumel * nnz; |
1317 | auto tensorShape = std::vector<int64_t>({(int64_t)nnz}); |
1318 | std::copy( |
1319 | valueShape.begin(), |
1320 | valueShape.end(), |
1321 | std::back_inserter(tensorShape)); |
1322 | values.push_back(output.narrow(0, offset, numel).reshape(tensorShape)); |
1323 | offset += numel; |
1324 | } |
1325 | |
1326 | return values; |
1327 | } |
1328 | }; |
1329 | |
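// CUDA variant of the dense allreduce: every input is staged into a pinned
// host buffer on a dedicated stream, the reduction runs on the host copies,
// and the results are copied back asynchronously. synchronize() blocks the
// caller's current stream on those copies.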
1330 | class AsyncAllreduceCUDAWork : public AsyncAllreduceWork { |
1331 | public: |
1332 | AsyncAllreduceCUDAWork( |
1333 | const std::shared_ptr<gloo::Context>& context, |
1334 | std::vector<at::Tensor>& inputs, |
1335 | ReduceOp reduceOp, |
1336 | uint32_t tag) |
1337 | : AsyncAllreduceWork(context, inputs, reduceOp, tag) { |
1338 | initializeStreamsEvents(inputs, streams, events); |
1339 | |
1340 | // Kick off copy from CUDA tensors to pinned CPU tensors. |
1341 | tmp.reserve(inputs.size()); |
1342 | c10::OptionalStreamGuard guard; |
1343 | for (const auto i : c10::irange(inputs.size())) { |
1344 | guard.reset_stream(streams[i]); |
1345 | tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); |
1346 | } |
1347 | } |
1348 | |
1349 | void run() override { |
1350 | // Synchronize with copy operations. |
1351 | for (const auto i : c10::irange(inputs.size())) { |
1352 | streams[i].synchronize(); |
1353 | } |
1354 | |
1355 | // Run allreduce on host side tensors. |
1356 | allreduce(tmp); |
1357 | |
1358 | c10::OptionalStreamGuard guard; |
1359 | for (const auto i : c10::irange(inputs.size())) { |
1360 | guard.reset_stream(streams[i]); |
1361 | inputs[i].copy_(tmp[i], /* non_blocking */ true); |
1362 | events[i].record(streams[i]); |
1363 | } |
1364 | } |
1365 | |
1366 | void synchronize() override { |
1367 | // Synchronize with the copy back to CUDA tensors. |
1368 | for (const auto i : c10::irange(inputs.size())) { |
1369 | c10::Device device = inputs[i].device(); |
1370 | events[i].block( |
1371 | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); |
1372 | } |
1373 | } |
1374 | |
1375 | std::vector<at::Tensor> tmp; |
1376 | std::vector<c10::Stream> streams; |
1377 | std::vector<c10::Event> events; |
1378 | }; |
1379 | |
1380 | class AsyncSparseAllreduceCUDAWork : public AsyncSparseAllreduceWork { |
1381 | public: |
1382 | AsyncSparseAllreduceCUDAWork( |
1383 | const std::shared_ptr<gloo::Context>& context, |
1384 | std::vector<at::Tensor>& inputs, |
1385 | uint32_t tag) |
1386 | : AsyncSparseAllreduceWork(context, inputs, tag) { |
1387 | initializeStreamsEvents(inputs, streams, events); |
1388 | |
1389 | // Kick off copy from CUDA tensors to CPU tensors. |
1390 | // Note that both coalescing the sparse tensor and copying it to CPU |
1391 | // memory must be performed asynchronously, or we block the caller. |
1392 | tmp.reserve(inputs.size()); |
1393 | c10::OptionalStreamGuard guard; |
1394 | for (const auto i : c10::irange(inputs.size())) { |
1395 | guard.reset_stream(streams[i]); |
1396 | tmp.push_back( |
1397 | inputs[i].coalesce().to(at::DeviceType::CPU, /*non_blocking=*/true)); |
1398 | } |
1399 | } |
1400 | |
1401 | void run() override { |
1402 | // Synchronize with copy operations. |
1403 | for (const auto i : c10::irange(inputs.size())) { |
1404 | streams[i].synchronize(); |
1405 | } |
1406 | |
1407 | // Run allreduce on host side tensors. |
1408 | auto output = allreduce(tmp); |
1409 | |
1410 | // Kick off copy back to the CUDA tensors. |
1411 | c10::OptionalStreamGuard guard; |
1412 | for (const auto i : c10::irange(inputs.size())) { |
1413 | guard.reset_stream(streams[i]); |
1414 | inputs[i].copy_(output, /*non_blocking=*/true); |
1415 | events[i].record(streams[i]); |
1416 | } |
1417 | } |
1418 | |
1419 | void synchronize() override { |
1420 | // Synchronize with the copy back to CUDA tensors. |
1421 | for (const auto i : c10::irange(inputs.size())) { |
1422 | c10::Device device = inputs[i].device(); |
1423 | events[i].block( |
1424 | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); |
1425 | } |
1426 | } |
1427 | |
1428 | std::vector<at::Tensor> tmp; |
1429 | std::vector<c10::Stream> streams; |
1430 | std::vector<c10::Event> events; |
1431 | }; |
1432 | |
1433 | } // namespace |
1434 | |
1435 | c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce( |
1436 | std::vector<at::Tensor>& inputs, |
1437 | const AllreduceOptions& opts) { |
1438 | static auto invalidArgument = [](const std::string& msg) { |
1439 | TORCH_CHECK(false, "ProcessGroupGloo::allreduce: " + msg); |
1440 | }; |
1441 | |
1442 | assertNonEmpty(invalidArgument, inputs); |
1443 | assertLayoutMatch(invalidArgument, inputs); |
1444 | assertTypeAndSizesMatch(invalidArgument, inputs); |
1445 | |
1446 | const auto& device = inputs[0].device(); |
1447 | switch (device.type()) { |
1448 | case at::kCPU: |
1449 | break; |
1450 | case at::kCUDA: |
1451 | // If the user gave us a CUDA tensor then CUDA must be loaded. |
1452 | TORCH_INTERNAL_ASSERT(at::hasCUDA()); |
1453 | break; |
1454 | default: |
1455 | invalidArgument(c10::str("unsupported device type " , device.type())); |
1456 | } |
1457 | |
1458 | const auto& layout = inputs[0].layout(); |
1459 | if (layout == c10::kSparse && opts.reduceOp != ReduceOp::SUM) { |
1460 | invalidArgument( |
1461 | "unsupported reduction operation " |
1462 | "(allreduce of sparse tensors only works with ReduceOp.SUM)" ); |
1463 | } |
1464 | |
1465 | c10::intrusive_ptr<AsyncWork> work; |
1466 | auto tag = nextTag(); |
1467 | auto context = getContext(tag); |
1468 | if (device.type() == at::kCPU) { |
1469 | if (layout == c10::kStrided) { |
1470 | work = c10::make_intrusive<AsyncAllreduceWork>( |
1471 | std::move(context), inputs, opts.reduceOp, tag); |
1472 | } else if (layout == c10::kSparse) { |
1473 | work = c10::make_intrusive<AsyncSparseAllreduceWork>( |
1474 | std::move(context), inputs, tag); |
1475 | } else { |
1476 | invalidArgument("unsupported layout" ); |
1477 | } |
1478 | } else if (device.type() == at::kCUDA) { |
1479 | if (layout == c10::kStrided) { |
1480 | work = c10::make_intrusive<AsyncAllreduceCUDAWork>( |
1481 | std::move(context), inputs, opts.reduceOp, tag); |
1482 | } else if (layout == c10::kSparse) { |
1483 | work = c10::make_intrusive<AsyncSparseAllreduceCUDAWork>( |
1484 | std::move(context), inputs, tag); |
1485 | } else { |
1486 | invalidArgument("unsupported layout" ); |
1487 | } |
1488 | } else { |
1489 | TORCH_CHECK(false, "Invalid backend" ); |
1490 | } |
1491 | |
1492 | enqueue(work); |
1493 | return work; |
1494 | } |
1495 | |
1496 | c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce_coalesced( |
1497 | std::vector<at::Tensor>& tensors, |
1498 | const AllreduceCoalescedOptions& opts) { |
1499 | static auto invalidArgument = [](const std::string& msg) { |
1500 | TORCH_CHECK(false, "ProcessGroupGloo::allreduce_coalesced: " + msg); |
1501 | }; |
1502 | assertNonEmpty(invalidArgument, tensors); |
1503 | |
  // Tensors will be flattened and concatenated (coalesced). This means that
  // input tensors must have the same device, layout and type.
1507 | assertLayoutMatch(invalidArgument, tensors); |
1508 | if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) { |
1509 | return t.options().type_equal(tensors[0].options()); |
1510 | })) { |
1511 | invalidArgument("tensors must all have the same type" ); |
1512 | } |
1513 | if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) { |
1514 | return t.device() == tensors[0].device(); |
1515 | })) { |
1516 | invalidArgument("tensors must all be on the same device" ); |
1517 | } |
1518 | |
1519 | const c10::Device& device = tensors[0].device(); |
1520 | const c10::Layout& layout = tensors[0].layout(); |
1521 | |
1522 | // invalid arguments are detected early here before any calls to nextTag() |
1523 | // which result in the collectiveCounter_ being incremented. |
1524 | switch (device.type()) { |
1525 | case c10::kCPU: |
1526 | break; |
1527 | default: |
1528 | invalidArgument(c10::str("unsupported device type " , device.type())); |
1529 | } |
1530 | |
1531 | switch (layout) { |
1532 | case c10::kStrided: |
1533 | break; |
1534 | default: |
1535 | invalidArgument("unsupported layout" ); |
1536 | } |
1537 | |
1538 | c10::intrusive_ptr<AsyncWork> work; |
1539 | const uint32_t tag = nextTag(); |
1540 | std::shared_ptr<gloo::Context> context = getContext(tag); |
1541 | if (device.type() == c10::kCPU) { |
1542 | if (layout == c10::kStrided) { |
1543 | work = c10::make_intrusive<AsyncAllreduceCoalescedWork>( |
1544 | std::move(context), tensors, opts.reduceOp, tag); |
1545 | } else { |
1546 | invalidArgument("unsupported layout" ); |
1547 | } |
1548 | } else { |
1549 | TORCH_CHECK(false, "Invalid backend" ); |
1550 | } |
1551 | enqueue(work); |
1552 | return work; |
1553 | } |
1554 | |
1555 | namespace { |
1556 | |
1557 | class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { |
1558 | public: |
1559 | AsyncReduceWork( |
1560 | const std::shared_ptr<gloo::Context>& context, |
1561 | std::vector<at::Tensor>& inputs, |
1562 | int rootRank, |
1563 | int rootTensor, |
1564 | ReduceOp reduceOp, |
1565 | uint32_t tag) |
      : ProcessGroupGloo::AsyncWork({inputs}, "gloo:reduce", inputs),
1567 | context(context), |
1568 | inputs(inputs), |
1569 | rootRank(rootRank), |
1570 | rootTensor(rootTensor), |
1571 | reduceOp(reduceOp), |
1572 | tag(tag) {} |
1573 | |
1574 | std::shared_ptr<gloo::Context> context; |
1575 | std::vector<at::Tensor> inputs; |
1576 | const int rootRank; |
1577 | const int rootTensor; |
1578 | const ReduceOp reduceOp; |
1579 | const uint32_t tag; |
1580 | |
1581 | void reduce(std::vector<at::Tensor>& tensors) { |
1582 | const auto& scalarType = tensors[0].scalar_type(); |
1583 | gloo::ReduceOptions opts(context); |
1584 | opts.setRoot(rootRank); |
1585 | opts.setTag(tag); |
1586 | opts.setReduceFunction(getFunction(scalarType, reduceOp)); |
1587 | GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensors[0]); |
1588 | gloo::reduce(opts); |
1589 | } |
1590 | |
1591 | void run() override { |
1592 | reduce(inputs); |
1593 | } |
1594 | |
1595 | protected: |
1596 | template <typename T> |
1597 | void getFunction(gloo::ReduceOptions::Func& fn, const ReduceOp op) { |
1598 | fn = toFunction<T>(op); |
1599 | } |
1600 | |
1601 | gloo::ReduceOptions::Func getFunction( |
1602 | const at::ScalarType& dtype, |
1603 | const ReduceOp op) { |
1604 | gloo::ReduceOptions::Func fn; |
1605 | GENERATE_ALL_TYPES(dtype, getFunction, fn, op); |
1606 | return fn; |
1607 | } |
1608 | }; |
1609 | |
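// CUDA variant of AsyncReduceWork: inputs are staged through pinned host
// copies, the reduce runs on those host tensors, and the result is copied
// back to the device on per-tensor streams. synchronize() blocks the current
// stream on the copy-back events instead of blocking the CPU.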
1610 | class AsyncReduceCUDAWork : public AsyncReduceWork { |
1611 | public: |
1612 | AsyncReduceCUDAWork( |
1613 | const std::shared_ptr<gloo::Context>& context, |
1614 | std::vector<at::Tensor>& inputs, |
1615 | int rootRank, |
1616 | int rootTensor, |
1617 | ReduceOp reduceOp, |
1618 | uint32_t tag) |
1619 | : AsyncReduceWork(context, inputs, rootRank, rootTensor, reduceOp, tag) { |
1620 | initializeStreamsEvents(inputs, streams, events); |
1621 | |
1622 | // Kick off copy from CUDA tensors to pinned CPU tensors. |
1623 | tmp.reserve(inputs.size()); |
1624 | c10::OptionalStreamGuard guard; |
1625 | for (const auto i : c10::irange(inputs.size())) { |
1626 | guard.reset_stream(streams[i]); |
1627 | tmp.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); |
1628 | } |
1629 | } |
1630 | |
1631 | void run() override { |
1632 | // Synchronize with copy operations. |
1633 | for (const auto i : c10::irange(inputs.size())) { |
1634 | streams[i].synchronize(); |
1635 | } |
1636 | |
1637 | // Run reduce on host side tensors. |
1638 | reduce(tmp); |
1639 | |
1640 | // Kick off copy back to the CUDA tensors. |
1641 | c10::OptionalStreamGuard guard; |
1642 | for (const auto i : c10::irange(inputs.size())) { |
1643 | guard.reset_stream(streams[i]); |
1644 | inputs[i].copy_(tmp[i], /* non_blocking */ true); |
1645 | events[i].record(streams[i]); |
1646 | } |
1647 | } |
1648 | |
1649 | void synchronize() override { |
1650 | // Synchronize with the copy back to CUDA tensors. |
1651 | for (const auto i : c10::irange(inputs.size())) { |
1652 | c10::Device device = inputs[i].device(); |
1653 | events[i].block( |
1654 | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); |
1655 | } |
1656 | } |
1657 | |
1658 | std::vector<at::Tensor> tmp; |
1659 | std::vector<c10::Stream> streams; |
1660 | std::vector<c10::Event> events; |
1661 | }; |
1662 | |
1663 | } // namespace |
1664 | |
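// A minimal usage sketch (illustrative only; assumes `pg` is an already
// constructed ProcessGroupGloo and `t` is a dense CPU tensor):
//
//   std::vector<at::Tensor> tensors = {t};
//   c10d::ReduceOptions opts;
//   opts.rootRank = 0;
//   opts.reduceOp = c10d::ReduceOp::SUM;
//   pg->reduce(tensors, opts)->wait();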
1665 | c10::intrusive_ptr<Work> ProcessGroupGloo::reduce( |
1666 | std::vector<at::Tensor>& inputs, |
1667 | const ReduceOptions& opts) { |
1668 | static auto invalidArgument = [](const std::string& msg) { |
1669 | TORCH_CHECK(false, "ProcessGroupGloo::reduce: " + msg); |
1670 | }; |
1671 | |
1672 | assertRootRank(invalidArgument, opts.rootRank, size_); |
1673 | assertRootTensor(invalidArgument, opts.rootTensor, inputs.size()); |
1674 | assertSingleElement(invalidArgument, inputs); |
1675 | assertDense(invalidArgument, inputs); |
1676 | |
1677 | const auto& device = inputs[0].device(); |
1678 | switch (device.type()) { |
1679 | case at::kCPU: |
1680 | break; |
1681 | case at::kCUDA: |
1682 | // If the user gave us a CUDA tensor then CUDA must be loaded. |
1683 | TORCH_INTERNAL_ASSERT(at::hasCUDA()); |
1684 | break; |
1685 | default: |
      invalidArgument(c10::str("unsupported device type ", device.type()));
1687 | } |
1688 | |
1689 | c10::intrusive_ptr<AsyncReduceWork> work; |
1690 | auto tag = nextTag(); |
1691 | auto context = getContext(tag); |
1692 | if (device.type() == at::kCPU) { |
1693 | work = c10::make_intrusive<AsyncReduceWork>( |
1694 | std::move(context), |
1695 | inputs, |
1696 | opts.rootRank, |
1697 | opts.rootTensor, |
1698 | opts.reduceOp, |
1699 | tag); |
1700 | } else if (device.type() == at::kCUDA) { |
1701 | work = c10::make_intrusive<AsyncReduceCUDAWork>( |
1702 | std::move(context), |
1703 | inputs, |
1704 | opts.rootRank, |
1705 | opts.rootTensor, |
1706 | opts.reduceOp, |
1707 | tag); |
1708 | } else { |
    TORCH_CHECK(false, "Invalid backend");
1710 | } |
1711 | enqueue(work); |
1712 | return work; |
1713 | } |
1714 | |
1715 | namespace { |
1716 | |
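// AsyncAllgatherWork gathers one tensor from every rank into every rank's
// output list. The inputs are flattened into a single contiguous tensor,
// gloo::allgather fills one flat output tensor, and the result is then
// unflattened (copied) into the user-provided output tensors.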
1717 | class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { |
1718 | public: |
1719 | AsyncAllgatherWork( |
1720 | const std::shared_ptr<gloo::Context>& context, |
1721 | std::vector<std::vector<at::Tensor>>& outputs, |
1722 | std::vector<at::Tensor>& inputs, |
1723 | uint32_t tag) |
      : ProcessGroupGloo::AsyncWork(outputs, "gloo:all_gather", inputs),
1725 | context(context), |
1726 | outputs(outputs), |
1727 | inputs(inputs), |
1728 | tag(tag) {} |
1729 | |
1730 | std::shared_ptr<gloo::Context> context; |
1731 | std::vector<std::vector<at::Tensor>> outputs; |
1732 | std::vector<at::Tensor> inputs; |
1733 | const uint32_t tag; |
1734 | |
1735 | void allgather( |
1736 | std::vector<std::vector<at::Tensor>>& outputs, |
1737 | std::vector<at::Tensor>& inputs) { |
1738 | const auto& scalarType = inputs[0].scalar_type(); |
1739 | gloo::AllgatherOptions opts(context); |
1740 | opts.setTag(tag); |
1741 | |
1742 | // Use single flattened input tensor. |
1743 | at::Tensor flatInputTensor = flattenDenseTensors(inputs); |
1744 | GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor); |
1745 | |
1746 | // Use single flat output tensor. |
1747 | // The first dimension corresponds to the index into outputs[N], |
1748 | // so copying into the actual output later is easy. |
1749 | at::Tensor flatOutputTensor = newLikeFlat(outputs[0]); |
1750 | GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); |
1751 | gloo::allgather(opts); |
1752 | |
1753 | // Unflatten into output tensors. |
1754 | for (auto& outputgroup : outputs) { |
1755 | for (const auto j : c10::irange(outputgroup.size())) { |
1756 | outputgroup[j].copy_(flatOutputTensor[j]); |
1757 | } |
1758 | } |
1759 | } |
1760 | |
1761 | void run() override { |
1762 | allgather(outputs, inputs); |
1763 | } |
1764 | }; |
1765 | |
1766 | // Note: current CUDA implementation holds the assumption that the |
1767 | // tensors in the nested output tensor vectors are on the same device. |
1768 | class AsyncAllgatherCUDAWork : public AsyncAllgatherWork { |
1769 | public: |
1770 | AsyncAllgatherCUDAWork( |
1771 | const std::shared_ptr<gloo::Context>& context, |
1772 | std::vector<std::vector<at::Tensor>>& outputs, |
1773 | std::vector<at::Tensor>& inputs, |
1774 | uint32_t tag) |
1775 | : AsyncAllgatherWork(context, outputs, inputs, tag) { |
1776 | initializeStreamsEvents(inputs, inputStreams, inputEvents); |
1777 | initializeStreamsEvents(outputs, outputStreams, outputEvents); |
1778 | |
1779 | // Kick off copy from CUDA tensors to pinned CPU tensors. |
1780 | tmpInputs.reserve(inputs.size()); |
1781 | c10::OptionalStreamGuard guard; |
1782 | for (const auto i : c10::irange(inputs.size())) { |
1783 | guard.reset_stream(inputStreams[i]); |
1784 | tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); |
1785 | } |
1786 | |
1787 | tmpOutputs.resize(outputs.size()); |
1788 | for (const auto i : c10::irange(outputs.size())) { |
1789 | tmpOutputs[i].reserve(outputs[i].size()); |
1790 | for (const auto j : c10::irange(outputs[i].size())) { |
1791 | tmpOutputs[i].push_back(pinnedLike(outputs[i][j])); |
1792 | } |
1793 | } |
1794 | } |
1795 | |
1796 | void run() override { |
1797 | // Synchronize with copy operations. |
1798 | for (const auto i : c10::irange(inputs.size())) { |
1799 | inputStreams[i].synchronize(); |
1800 | } |
1801 | |
1802 | for (const auto i : c10::irange(outputs.size())) { |
1803 | outputStreams[i].synchronize(); |
1804 | } |
1805 | |
1806 | // Run allgather on host side tensors. |
1807 | allgather(tmpOutputs, tmpInputs); |
1808 | |
1809 | // Kick off copy back to the CUDA tensors. |
1810 | c10::OptionalStreamGuard guard; |
1811 | for (const auto i : c10::irange(outputs.size())) { |
1812 | guard.reset_stream(outputStreams[i]); |
1813 | for (const auto j : c10::irange(outputs[i].size())) { |
1814 | outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true); |
1815 | } |
1816 | outputEvents[i].record(outputStreams[i]); |
1817 | } |
1818 | } |
1819 | |
1820 | void synchronize() override { |
1821 | // Synchronize with the copy back to CUDA tensors. |
1822 | for (const auto i : c10::irange(outputs.size())) { |
1823 | c10::Device device = outputs[i][0].device(); |
1824 | outputEvents[i].block( |
1825 | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); |
1826 | } |
1827 | } |
1828 | |
1829 | std::vector<at::Tensor> tmpInputs; |
1830 | std::vector<c10::Stream> inputStreams; |
1831 | std::vector<c10::Event> inputEvents; |
1832 | |
1833 | std::vector<std::vector<at::Tensor>> tmpOutputs; |
1834 | std::vector<c10::Stream> outputStreams; |
1835 | std::vector<c10::Event> outputEvents; |
1836 | }; |
1837 | |
1838 | } // namespace |
1839 | |
1840 | // Note: current CUDA implementation holds the assumption that the |
1841 | // tensors in the nested output tensor vectors are on the same device. |
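// A minimal usage sketch (illustrative only; assumes `pg` is a constructed
// ProcessGroupGloo of size `world_size` and `t` is a dense CPU tensor):
//
//   std::vector<at::Tensor> inputs = {t};
//   std::vector<std::vector<at::Tensor>> outputs(1);
//   for (int i = 0; i < world_size; ++i) {
//     outputs[0].push_back(at::empty_like(t));
//   }
//   pg->allgather(outputs, inputs)->wait();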
1842 | c10::intrusive_ptr<Work> ProcessGroupGloo::allgather( |
1843 | std::vector<std::vector<at::Tensor>>& outputs, |
1844 | std::vector<at::Tensor>& inputs, |
1845 | const AllgatherOptions& opts) { |
1846 | static auto invalidArgument = [](const std::string& msg) { |
1847 | TORCH_CHECK(false, "ProcessGroupGloo::allgather: " + msg); |
1848 | }; |
1849 | |
1850 | if (inputs.empty()) { |
    invalidArgument("requires non-empty input tensor list");
1852 | } |
1853 | |
1854 | if (inputs.size() != outputs.size()) { |
1855 | invalidArgument( |
        "requires input/output tensor lists to have the same length");
1857 | } |
1858 | |
1859 | for (const auto i : c10::irange(outputs.size())) { |
1860 | const auto expected = inputs.size() * getSize(); |
1861 | const auto actual = outputs[i].size(); |
1862 | if (actual != expected) { |
1863 | invalidArgument( |
1864 | "invalid output tensor list at index " + std::to_string(i) + |
1865 | " (expected length " + std::to_string(expected) + ", got " + |
          std::to_string(actual) + ")");
1867 | } |
1868 | } |
1869 | |
1870 | assertDense(invalidArgument, inputs); |
1871 | |
1872 | // Expect all input/output tensors to have the same type and sizes |
1873 | const auto& options = inputs[0].options(); |
1874 | const auto& sizes = inputs[0].sizes(); |
1875 | assertTypeAndSizesMatch(invalidArgument, inputs, options, sizes); |
1876 | for (const auto& output : outputs) { |
1877 | assertTypeAndSizesMatch(invalidArgument, output, options, sizes); |
1878 | } |
1879 | |
1880 | const auto& device = inputs[0].device(); |
1881 | switch (device.type()) { |
1882 | case at::kCPU: |
1883 | break; |
1884 | case at::kCUDA: |
1885 | // If the user gave us a CUDA tensor then CUDA must be loaded. |
1886 | TORCH_INTERNAL_ASSERT(at::hasCUDA()); |
1887 | break; |
1888 | default: |
      invalidArgument(c10::str("unsupported device type ", device.type()));
1890 | } |
1891 | |
1892 | c10::intrusive_ptr<AsyncAllgatherWork> work; |
1893 | auto tag = nextTag(); |
1894 | auto context = getContext(tag); |
1895 | if (device.type() == at::kCPU) { |
1896 | work = c10::make_intrusive<AsyncAllgatherWork>( |
1897 | std::move(context), outputs, inputs, tag); |
1898 | } else if (device.type() == at::kCUDA) { |
1899 | work = c10::make_intrusive<AsyncAllgatherCUDAWork>( |
1900 | std::move(context), outputs, inputs, tag); |
1901 | } else { |
    TORCH_CHECK(false, "Invalid backend");
1903 | } |
1904 | enqueue(work); |
1905 | return work; |
1906 | } |
1907 | |
1908 | namespace { |
1909 | |
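// AsyncAllgatherCoalescedWork allgathers a whole list of (possibly
// differently shaped) tensors per rank in a single gloo::allgather call by
// flattening the local input list and carving the flat result back into the
// per-rank output lists.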
1910 | class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { |
1911 | public: |
1912 | AsyncAllgatherCoalescedWork( |
1913 | const std::shared_ptr<gloo::Context>& context, |
1914 | std::vector<std::vector<at::Tensor>>& output_lists, |
1915 | std::vector<at::Tensor>& input_list, |
1916 | uint32_t tag) |
1917 | : ProcessGroupGloo::AsyncWork( |
1918 | output_lists, |
            "gloo:all_gather",
1920 | input_list), |
1921 | context(context), |
1922 | output_lists(output_lists), |
1923 | input_list(input_list), |
1924 | tag(tag) {} |
1925 | |
1926 | std::shared_ptr<gloo::Context> context; |
1927 | std::vector<std::vector<at::Tensor>> output_lists; |
1928 | std::vector<at::Tensor> input_list; |
1929 | const uint32_t tag; |
1930 | |
1931 | void allgather_coalesced() { |
1932 | assert(!output_lists.empty()); |
1933 | assert(!output_lists[0].empty()); |
1934 | assert(!input_list.empty()); |
1935 | |
1936 | const auto& scalarType = input_list[0].scalar_type(); |
1937 | gloo::AllgatherOptions opts(context); |
1938 | opts.setTag(tag); |
1939 | |
1940 | // Use single flattened input tensor. |
1941 | at::Tensor flatInputTensor = flattenDenseTensors(input_list); |
1942 | GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor); |
1943 | |
1944 | // Compute total number of elements we need to allocate for all tensors |
1945 | // requested. |
1946 | int64_t output_numel = 0; |
1947 | for (const auto& t : output_lists[0]) { |
1948 | output_numel += t.numel(); |
1949 | } |
1950 | output_numel *= output_lists.size(); |
1951 | // Use single flat output tensor. |
1952 | at::Tensor flatOutputTensor = |
1953 | at::empty({output_numel}, output_lists[0][0].options()); |
1954 | GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); |
1955 | gloo::allgather(opts); |
1956 | |
1957 | int64_t current_element = 0; |
1958 | for (auto& output_list : output_lists) { |
1959 | for (auto& output_tensor : output_list) { |
1960 | output_tensor.copy_( |
1961 | flatOutputTensor.narrow(0, current_element, output_tensor.numel()) |
1962 | .reshape(output_tensor.sizes()), |
1963 | true); |
1964 | current_element += output_tensor.numel(); |
1965 | } |
1966 | } |
1967 | } |
1968 | |
1969 | void run() override { |
1970 | allgather_coalesced(); |
1971 | } |
1972 | }; |
1973 | |
1974 | } // namespace |
1975 | |
1976 | c10::intrusive_ptr<Work> ProcessGroupGloo::allgather_coalesced( |
1977 | std::vector<std::vector<at::Tensor>>& output_lists, |
1978 | std::vector<at::Tensor>& input_list, |
1979 | const AllgatherOptions& /* unused */) { |
1980 | static auto invalidArgument = [](const std::string& msg) { |
1981 | TORCH_CHECK(false, "ProcessGroupGloo::allgather_coalesced: " + msg); |
1982 | }; |
1983 | |
1984 | if (input_list.empty()) { |
    invalidArgument("requires non-empty input tensor list");
1986 | } |
1987 | |
  if (output_lists.size() != static_cast<size_t>(getSize())) {
    invalidArgument("number of output lists should be equal to world size");
1990 | } |
1991 | |
1992 | assertSameDevice(invalidArgument, input_list); |
1993 | |
1994 | // Expect i'th tensor of each list from 'output_lists' match i'th tensor |
1995 | // from 'input_list' in type and size. |
1996 | for (const auto& output_list : output_lists) { |
1997 | if (output_list.size() != input_list.size()) { |
1998 | invalidArgument( |
1999 | "invalid output size: (expected length " + |
2000 | std::to_string(input_list.size()) + ", got " + |
          std::to_string(output_list.size()) + ")");
2002 | } |
2003 | for (const auto i : c10::irange(output_list.size())) { |
2004 | const auto expected = input_list[i].sizes(); |
2005 | const auto actual = output_list[i].sizes(); |
2006 | if (actual != expected) { |
2007 | invalidArgument( |
2008 | "invalid size of output tensor at index " + std::to_string(i) + |
2009 | " (expected length " + toString(expected) + ", got " + |
            toString(actual) + ")");
2011 | } |
2012 | if (!input_list[i].options().type_equal(output_list[i].options())) { |
2013 | invalidArgument( |
2014 | "invalid tensor type at index " + std::to_string(i) + |
2015 | " (expected " + input_list[i].toString() + ", got " + |
            output_list[i].toString() + ")");
2017 | } |
2018 | } |
2019 | } |
2020 | |
2021 | assertDense(invalidArgument, input_list); |
2022 | |
2023 | auto tag = nextTag(); |
2024 | auto context = getContext(tag); |
2025 | auto work = c10::make_intrusive<AsyncAllgatherCoalescedWork>( |
2026 | std::move(context), output_lists, input_list, tag); |
2027 | enqueue(work); |
2028 | return work; |
2029 | } |
2030 | |
2031 | c10::intrusive_ptr<Work> ProcessGroupGloo::_allgather_base( |
2032 | at::Tensor& /*unused */, |
2033 | at::Tensor& /*unused */, |
2034 | const AllgatherOptions& /*unused */) { |
  TORCH_CHECK(false, "no support for _allgather_base in Gloo process group");
2036 | } |
2037 | |
2038 | namespace { |
2039 | |
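// AsyncGatherWork gathers a single tensor from every rank onto the root rank
// only. The root receives into one flat temporary tensor and then copies each
// per-rank slice into the corresponding entry of outputs[0]; non-root ranks
// only contribute their input.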
2040 | class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { |
2041 | public: |
2042 | AsyncGatherWork( |
2043 | const std::shared_ptr<gloo::Context>& context, |
2044 | std::vector<std::vector<at::Tensor>>& outputs, |
2045 | std::vector<at::Tensor>& inputs, |
2046 | int root, |
2047 | uint32_t tag) |
      : ProcessGroupGloo::AsyncWork(outputs, "gloo:gather", inputs),
2049 | context(context), |
2050 | outputs(outputs), |
2051 | inputs(inputs), |
2052 | root(root), |
2053 | tag(tag) {} |
2054 | |
2055 | std::shared_ptr<gloo::Context> context; |
2056 | std::vector<std::vector<at::Tensor>> outputs; |
2057 | std::vector<at::Tensor> inputs; |
2058 | const int root; |
2059 | const uint32_t tag; |
2060 | |
2061 | void gather( |
2062 | std::vector<std::vector<at::Tensor>>& outputs, |
2063 | std::vector<at::Tensor>& inputs) { |
2064 | const auto scalarType = inputs[0].scalar_type(); |
2065 | gloo::GatherOptions opts(context); |
2066 | opts.setRoot(root); |
2067 | opts.setTag(tag); |
2068 | |
2069 | // Set single temporary tensor on root process. |
2070 | // This is later scattered to the separate output tensors. |
2071 | at::Tensor flatOutputTensor; |
2072 | if (context->rank == root) { |
2073 | flatOutputTensor = newLikeFlat(outputs[0]); |
2074 | GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); |
2075 | } |
2076 | |
2077 | // Set single input tensor on all processes. |
2078 | GENERATE_ALL_TYPES(scalarType, setInput, opts, inputs[0]); |
2079 | gloo::gather(opts); |
2080 | |
2081 | // Unflatten into output tensors on root process. |
2082 | if (context->rank == root) { |
2083 | for (const auto i : c10::irange(outputs[0].size())) { |
2084 | outputs[0][i].copy_(flatOutputTensor[i]); |
2085 | } |
2086 | } |
2087 | } |
2088 | |
2089 | void run() override { |
2090 | gather(outputs, inputs); |
2091 | } |
2092 | }; |
2093 | |
2094 | // Note: current CUDA implementation holds the assumptions: |
2095 | // - inputs.size() is 1 |
2096 | // - outputs.size() is 1 |
// - the size of the nested output tensor list is world size, i.e.,
//   outputs[0].size() equals the world size
2099 | class AsyncGatherCUDAWork : public AsyncGatherWork { |
2100 | public: |
2101 | AsyncGatherCUDAWork( |
2102 | const std::shared_ptr<gloo::Context>& context, |
2103 | std::vector<std::vector<at::Tensor>>& outputs, |
2104 | std::vector<at::Tensor>& inputs, |
2105 | int root, |
2106 | uint32_t tag) |
2107 | : AsyncGatherWork(context, outputs, inputs, root, tag) { |
2108 | initializeStreamsEvents(inputs, inputStreams, inputEvents); |
2109 | initializeStreamsEvents(outputs, outputStreams, outputEvents); |
2110 | |
2111 | // Kick off copy from CUDA tensors to pinned CPU tensors. |
2112 | tmpInputs.reserve(inputs.size()); |
2113 | c10::OptionalStreamGuard guard; |
2114 | for (const auto i : c10::irange(inputs.size())) { |
2115 | guard.reset_stream(inputStreams[i]); |
2116 | tmpInputs.push_back(pinnedLike(inputs[i]).copy_(inputs[i], true)); |
2117 | } |
2118 | |
2119 | tmpOutputs.resize(outputs.size()); |
2120 | for (const auto i : c10::irange(outputs.size())) { |
2121 | tmpOutputs[i].reserve(outputs[i].size()); |
2122 | for (const auto j : c10::irange(outputs[i].size())) { |
2123 | tmpOutputs[i].push_back(pinnedLike(outputs[i][j])); |
2124 | } |
2125 | } |
2126 | } |
2127 | |
2128 | void run() override { |
2129 | // Synchronize with copy operations. |
2130 | for (const auto i : c10::irange(inputs.size())) { |
2131 | inputStreams[i].synchronize(); |
2132 | } |
2133 | |
2134 | for (const auto i : c10::irange(outputs.size())) { |
2135 | outputStreams[i].synchronize(); |
2136 | } |
2137 | |
2138 | // Run gather on host side tensors. |
2139 | gather(tmpOutputs, tmpInputs); |
2140 | |
2141 | // Kick off copy back to the CUDA tensors. |
2142 | c10::OptionalStreamGuard guard; |
2143 | for (const auto i : c10::irange(outputs.size())) { |
2144 | guard.reset_stream(outputStreams[i]); |
2145 | for (const auto j : c10::irange(outputs[i].size())) { |
2146 | outputs[i][j].copy_(tmpOutputs[i][j], /* non_blocking */ true); |
2147 | } |
2148 | outputEvents[i].record(outputStreams[i]); |
2149 | } |
2150 | } |
2151 | |
2152 | void synchronize() override { |
2153 | // Synchronize with the copy back to CUDA tensors. |
2154 | for (const auto i : c10::irange(outputs.size())) { |
2155 | c10::Device device = outputs[i][0].device(); |
2156 | outputEvents[i].block( |
2157 | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); |
2158 | } |
2159 | } |
2160 | |
2161 | std::vector<at::Tensor> tmpInputs; |
2162 | std::vector<c10::Stream> inputStreams; |
2163 | std::vector<c10::Event> inputEvents; |
2164 | |
2165 | std::vector<std::vector<at::Tensor>> tmpOutputs; |
2166 | std::vector<c10::Stream> outputStreams; |
2167 | std::vector<c10::Event> outputEvents; |
2168 | }; |
2169 | |
2170 | } // namespace |
2171 | |
2172 | c10::intrusive_ptr<Work> ProcessGroupGloo::gather( |
2173 | std::vector<std::vector<at::Tensor>>& outputs, |
2174 | std::vector<at::Tensor>& inputs, |
2175 | const GatherOptions& opts) { |
2176 | static auto invalidArgument = [](const std::string& msg) { |
2177 | TORCH_CHECK(false, "ProcessGroupGloo::gather: " + msg); |
2178 | }; |
2179 | |
2180 | assertRootRank(invalidArgument, opts.rootRank, size_); |
2181 | assertSingleElementInput(invalidArgument, inputs); |
2182 | assertDense(invalidArgument, inputs); |
2183 | |
2184 | if (getRank() == opts.rootRank) { |
2185 | if (outputs.size() != 1) { |
2186 | std::stringstream ss; |
2187 | ss << "requires a single-element output list containing a list with " |
         << getSize() << " tensors.";
2189 | invalidArgument(ss.str()); |
2190 | } else if (outputs[0].size() != static_cast<size_t>(getSize())) { |
2191 | std::stringstream ss; |
2192 | ss << "Incorrect output list size " << outputs[0].size() |
2193 | << ". Output list size should be " << getSize() |
         << ", same as size of the process group.";
2195 | invalidArgument(ss.str()); |
2196 | } |
2197 | |
2198 | const auto& options = inputs[0].options(); |
2199 | const auto& sizes = inputs[0].sizes(); |
2200 | assertTypeAndSizesMatch(invalidArgument, outputs[0], options, sizes); |
2201 | } else { |
2202 | if (!outputs.empty()) { |
      invalidArgument("requires empty output on non-root");
2204 | } |
2205 | } |
2206 | |
2207 | const auto& device = inputs[0].device(); |
2208 | switch (device.type()) { |
2209 | case at::kCPU: |
2210 | break; |
2211 | case at::kCUDA: |
2212 | // If the user gave us a CUDA tensor then CUDA must be loaded. |
2213 | TORCH_INTERNAL_ASSERT(at::hasCUDA()); |
2214 | break; |
2215 | default: |
      invalidArgument(c10::str("unsupported device type ", device.type()));
2217 | } |
2218 | |
2219 | c10::intrusive_ptr<AsyncGatherWork> work; |
2220 | auto tag = nextTag(); |
2221 | auto context = getContext(tag); |
2222 | if (device.type() == at::kCPU) { |
2223 | work = c10::make_intrusive<AsyncGatherWork>( |
2224 | std::move(context), outputs, inputs, opts.rootRank, tag); |
2225 | } else if (device.type() == at::kCUDA) { |
2226 | work = c10::make_intrusive<AsyncGatherCUDAWork>( |
2227 | std::move(context), outputs, inputs, opts.rootRank, tag); |
2228 | } else { |
    TORCH_CHECK(false, "Invalid backend");
2230 | } |
2231 | enqueue(work); |
2232 | return work; |
2233 | } |
2234 | |
2235 | namespace { |
2236 | |
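// AsyncScatterWork is the inverse of gather: the root rank provides one input
// tensor per rank (inputs[0]) and every rank, including the root, receives
// exactly one tensor into outputs[0].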
2237 | class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { |
2238 | public: |
2239 | AsyncScatterWork( |
2240 | const std::shared_ptr<gloo::Context>& context, |
2241 | std::vector<at::Tensor>& outputs, |
2242 | std::vector<std::vector<at::Tensor>>& inputs, |
2243 | int root, |
2244 | uint32_t tag) |
2245 | : ProcessGroupGloo::AsyncWork( |
2246 | {outputs}, |
            "gloo:scatter",
2248 | !inputs.empty() ? c10::optional<std::vector<at::Tensor>>(inputs[0]) |
2249 | : c10::nullopt), |
2250 | context(context), |
2251 | outputs(outputs), |
2252 | inputs(inputs), |
2253 | root(root), |
2254 | tag(tag) {} |
2255 | |
2256 | std::shared_ptr<gloo::Context> context; |
2257 | std::vector<at::Tensor> outputs; |
2258 | std::vector<std::vector<at::Tensor>> inputs; |
2259 | const int root; |
2260 | const uint32_t tag; |
2261 | |
2262 | void scatter( |
2263 | std::vector<at::Tensor>& outputs, |
2264 | std::vector<std::vector<at::Tensor>>& inputs) { |
2265 | const auto scalarType = outputs[0].scalar_type(); |
2266 | gloo::ScatterOptions opts(context); |
2267 | opts.setRoot(root); |
2268 | opts.setTag(tag); |
2269 | |
2270 | // Set list of input tensors on root process |
2271 | if (context->rank == root) { |
2272 | GENERATE_ALL_TYPES(scalarType, setInputs, opts, inputs[0]); |
2273 | } |
2274 | |
2275 | // Set single output tensor on all processes |
2276 | GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputs[0]); |
2277 | gloo::scatter(opts); |
2278 | } |
2279 | |
2280 | void run() override { |
2281 | scatter(outputs, inputs); |
2282 | } |
2283 | }; |
2284 | |
2285 | class AsyncScatterCUDAWork : public AsyncScatterWork { |
2286 | public: |
2287 | AsyncScatterCUDAWork( |
2288 | const std::shared_ptr<gloo::Context>& context, |
2289 | std::vector<at::Tensor>& outputs, |
2290 | std::vector<std::vector<at::Tensor>>& inputs, |
2291 | int root, |
2292 | uint32_t tag) |
2293 | : AsyncScatterWork(context, outputs, inputs, root, tag) { |
2294 | initializeStreamsEvents(inputs, inputStreams, inputEvents); |
2295 | initializeStreamsEvents(outputs, outputStreams, outputEvents); |
2296 | |
2297 | // Kick off copy from CUDA tensors to pinned CPU tensors. |
2298 | tmpInputs.resize(inputs.size()); |
2299 | c10::OptionalStreamGuard guard; |
2300 | for (const auto i : c10::irange(inputs.size())) { |
2301 | guard.reset_stream(inputStreams[i]); |
2302 | tmpInputs[i].reserve(inputs[i].size()); |
2303 | for (const auto j : c10::irange(inputs[i].size())) { |
2304 | tmpInputs[i].push_back( |
2305 | pinnedLike(inputs[i][j]).copy_(inputs[i][j], true)); |
2306 | } |
2307 | } |
2308 | |
2309 | tmpOutputs.reserve(outputs.size()); |
2310 | for (auto& output : outputs) { |
2311 | tmpOutputs.push_back(pinnedLike(output)); |
2312 | } |
2313 | } |
2314 | |
2315 | void run() override { |
2316 | // Synchronize with copy operations. |
2317 | for (const auto i : c10::irange(inputs.size())) { |
2318 | inputStreams[i].synchronize(); |
2319 | } |
2320 | for (const auto i : c10::irange(outputs.size())) { |
2321 | outputStreams[i].synchronize(); |
2322 | } |
2323 | |
2324 | // Run scatter on host side tensors. |
2325 | scatter(tmpOutputs, tmpInputs); |
2326 | |
2327 | // Kick off copy back to the CUDA tensors. |
2328 | c10::OptionalStreamGuard guard; |
2329 | for (const auto i : c10::irange(outputs.size())) { |
2330 | guard.reset_stream(outputStreams[i]); |
2331 | outputs[i].copy_(tmpOutputs[i], /* non_blocking */ true); |
2332 | outputEvents[i].record(outputStreams[i]); |
2333 | } |
2334 | } |
2335 | |
2336 | void synchronize() override { |
2337 | // Synchronize with the copy back to CUDA tensors. |
2338 | for (const auto i : c10::irange(outputs.size())) { |
2339 | c10::Device device = outputs[i].device(); |
2340 | outputEvents[i].block( |
2341 | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); |
2342 | } |
2343 | } |
2344 | |
2345 | std::vector<at::Tensor> tmpOutputs; |
2346 | std::vector<c10::Stream> outputStreams; |
2347 | std::vector<c10::Event> outputEvents; |
2348 | |
2349 | std::vector<std::vector<at::Tensor>> tmpInputs; |
2350 | std::vector<c10::Stream> inputStreams; |
2351 | std::vector<c10::Event> inputEvents; |
2352 | }; |
2353 | |
2354 | } // namespace |
2355 | |
2356 | c10::intrusive_ptr<Work> ProcessGroupGloo::scatter( |
2357 | std::vector<at::Tensor>& outputs, |
2358 | std::vector<std::vector<at::Tensor>>& inputs, |
2359 | const ScatterOptions& opts) { |
2360 | static auto invalidArgument = [](const std::string& msg) { |
2361 | TORCH_CHECK(false, "ProcessGroupGloo::scatter: " + msg); |
2362 | }; |
2363 | |
2364 | assertRootRank(invalidArgument, opts.rootRank, size_); |
2365 | assertSingleElementOutput(invalidArgument, outputs); |
2366 | assertDense(invalidArgument, outputs); |
2367 | |
2368 | if (getRank() == opts.rootRank) { |
2369 | if (inputs.size() != 1) { |
2370 | std::stringstream ss; |
2371 | ss << "requires a single-element input list containing a list with " |
         << getSize() << " tensors";
2373 | invalidArgument(ss.str()); |
2374 | } else if (inputs[0].size() != static_cast<size_t>(getSize())) { |
2375 | std::stringstream ss; |
2376 | ss << "Incorrect input list size " << inputs[0].size() |
2377 | << ". Input list size should be " << getSize() |
         << ", same as size of the process group.";
2379 | invalidArgument(ss.str()); |
2380 | } |
2381 | const auto& options = outputs[0].options(); |
2382 | const auto& sizes = outputs[0].sizes(); |
2383 | assertTypeAndSizesMatch(invalidArgument, inputs[0], options, sizes); |
2384 | } else { |
2385 | if (!inputs.empty()) { |
      invalidArgument("requires empty input on non-root");
2387 | } |
2388 | } |
2389 | |
2390 | const auto& device = outputs[0].device(); |
2391 | switch (device.type()) { |
2392 | case at::kCPU: |
2393 | break; |
2394 | case at::kCUDA: |
2395 | // If the user gave us a CUDA tensor then CUDA must be loaded. |
2396 | TORCH_INTERNAL_ASSERT(at::hasCUDA()); |
2397 | break; |
2398 | default: |
      invalidArgument(c10::str("unsupported device type ", device.type()));
2400 | } |
2401 | |
2402 | c10::intrusive_ptr<AsyncScatterWork> work; |
2403 | auto tag = nextTag(); |
2404 | auto context = getContext(tag); |
2405 | if (device.type() == at::kCPU) { |
2406 | work = c10::make_intrusive<AsyncScatterWork>( |
2407 | std::move(context), outputs, inputs, opts.rootRank, tag); |
2408 | } else if (device.type() == at::kCUDA) { |
2409 | work = c10::make_intrusive<AsyncScatterCUDAWork>( |
2410 | std::move(context), outputs, inputs, opts.rootRank, tag); |
2411 | } else { |
    TORCH_CHECK(false, "Invalid backend");
2413 | } |
2414 | enqueue(work); |
2415 | return work; |
2416 | } |
2417 | |
2418 | c10::intrusive_ptr<Work> ProcessGroupGloo::reduce_scatter( |
2419 | std::vector<at::Tensor>& outputs, |
2420 | std::vector<std::vector<at::Tensor>>& inputs, |
2421 | const ReduceScatterOptions& opts) { |
  TORCH_CHECK(false, "ProcessGroupGloo does not support reduce_scatter");
2423 | } |
2424 | |
2425 | namespace { |
2426 | |
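// AsyncAlltoallWork exchanges a slice of a single tensor with every rank.
// With empty split-size lists it performs an evenly split gloo::alltoall;
// otherwise it derives per-rank lengths and offsets from the split sizes and
// runs gloo::alltoallv.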
2427 | class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { |
2428 | public: |
2429 | AsyncAlltoallWork( |
2430 | const std::shared_ptr<gloo::Context>& context, |
2431 | at::Tensor& outputTensor, |
2432 | at::Tensor& inputTensor, |
2433 | std::vector<int64_t>& outputCounts, |
2434 | std::vector<int64_t>& inputCounts, |
2435 | uint32_t tag) |
2436 | : ProcessGroupGloo::AsyncWork( |
2437 | {{outputTensor}}, |
            "gloo:all_to_all",
2439 | c10::optional<std::vector<at::Tensor>>({inputTensor})), |
2440 | context(context), |
2441 | outputTensor(outputTensor), |
2442 | inputTensor(inputTensor), |
2443 | outputCounts(std::move(outputCounts)), |
2444 | inputCounts(std::move(inputCounts)), |
2445 | tag(tag) {} |
2446 | |
2447 | std::shared_ptr<gloo::Context> context; |
2448 | at::Tensor outputTensor; |
2449 | at::Tensor inputTensor; |
2450 | std::vector<int64_t> outputCounts; |
2451 | std::vector<int64_t> inputCounts; |
2452 | const uint32_t tag; |
2453 | |
2454 | void alltoall(at::Tensor& outputTensor, at::Tensor& inputTensor) { |
2455 | const auto scalarType = outputTensor.scalar_type(); |
2456 | if (outputCounts.empty() && inputCounts.empty()) { |
2457 | // Gloo alltoall |
2458 | gloo::AlltoallOptions opts(context); |
2459 | opts.setTag(tag); |
2460 | GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor); |
2461 | GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor); |
2462 | gloo::alltoall(opts); |
2463 | } else { |
2464 | // Gloo alltoallv |
2465 | c10d::checkSplitSizes(inputCounts, inputTensor, context->size); |
2466 | c10d::checkSplitSizes(outputCounts, outputTensor, context->size); |
2467 | std::vector<int64_t> sendCounts(context->size); |
2468 | std::vector<int64_t> recvCounts(context->size); |
2469 | std::vector<int64_t> sendOffsets(context->size); |
2470 | std::vector<int64_t> recvOffsets(context->size); |
2471 | c10d::computeLengthsAndOffsets( |
2472 | inputCounts, inputTensor, &sendCounts, &sendOffsets); |
2473 | c10d::computeLengthsAndOffsets( |
2474 | outputCounts, outputTensor, &recvCounts, &recvOffsets); |
2475 | gloo::AlltoallvOptions opts(context); |
2476 | opts.setTag(tag); |
2477 | GENERATE_ALL_TYPES(scalarType, setInput, opts, inputTensor, sendCounts); |
2478 | GENERATE_ALL_TYPES(scalarType, setOutput, opts, outputTensor, recvCounts); |
2479 | gloo::alltoallv(opts); |
2480 | } |
2481 | } |
2482 | |
2483 | void run() override { |
2484 | alltoall(outputTensor, inputTensor); |
2485 | } |
2486 | }; |
2487 | |
2488 | class AsyncAlltoallCUDAWork : public AsyncAlltoallWork { |
2489 | public: |
2490 | AsyncAlltoallCUDAWork( |
2491 | const std::shared_ptr<gloo::Context>& context, |
2492 | at::Tensor& outputTensor, |
2493 | at::Tensor& inputTensor, |
2494 | std::vector<int64_t>& outputCounts, |
2495 | std::vector<int64_t>& inputCounts, |
2496 | uint32_t tag) |
2497 | : AsyncAlltoallWork( |
2498 | context, |
2499 | outputTensor, |
2500 | inputTensor, |
2501 | outputCounts, |
2502 | inputCounts, |
2503 | tag) { |
2504 | initializeStreamsEvents({inputTensor}, inputStreams, inputEvents); |
2505 | initializeStreamsEvents({outputTensor}, outputStreams, outputEvents); |
2506 | |
2507 | // Kick off copy from CUDA tensors to pinned CPU tensors. |
2508 | c10::OptionalStreamGuard guard; |
2509 | guard.reset_stream(inputStreams.front()); |
2510 | cpuInput = pinnedLike(inputTensor).copy_(inputTensor, true); |
2511 | |
2512 | guard.reset_stream(outputStreams.front()); |
2513 | cpuOutput = pinnedLike(outputTensor); |
2514 | } |
2515 | |
2516 | void run() override { |
2517 | // Synchronize with copy operations. |
2518 | inputStreams.front().synchronize(); |
2519 | outputStreams.front().synchronize(); |
2520 | |
2521 | // Run alltoall on host side tensors. |
2522 | alltoall(cpuOutput, cpuInput); |
2523 | |
2524 | // Kick off copy back to the CUDA tensors. |
2525 | c10::OptionalStreamGuard guard; |
2526 | guard.reset_stream(outputStreams.front()); |
2527 | outputTensor.copy_(cpuOutput, /* non_blocking */ true); |
2528 | outputEvents.front().record(outputStreams.front()); |
2529 | } |
2530 | |
2531 | void synchronize() override { |
2532 | // Synchronize with the copy back to CUDA tensors. |
2533 | c10::Device device = outputTensor.device(); |
2534 | outputEvents.front().block( |
2535 | c10::impl::VirtualGuardImpl(device.type()).getStream(device)); |
2536 | } |
2537 | |
2538 | at::Tensor cpuOutput; |
2539 | std::vector<c10::Stream> outputStreams; |
2540 | std::vector<c10::Event> outputEvents; |
2541 | |
2542 | at::Tensor cpuInput; |
2543 | std::vector<c10::Stream> inputStreams; |
2544 | std::vector<c10::Event> inputEvents; |
2545 | }; |
2546 | |
2547 | } // namespace |
2548 | |
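// A minimal usage sketch for the evenly split case (illustrative only;
// assumes `pg` has size `world_size` and the tensor length is a multiple of
// `world_size`):
//
//   at::Tensor input = at::arange(4 * world_size, at::kFloat);
//   at::Tensor output = at::empty_like(input);
//   std::vector<int64_t> outCounts, inCounts;  // empty => even split
//   pg->alltoall_base(output, input, outCounts, inCounts, AllToAllOptions())
//       ->wait();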
2549 | c10::intrusive_ptr<Work> ProcessGroupGloo::alltoall_base( |
2550 | at::Tensor& outputTensor, |
2551 | at::Tensor& inputTensor, |
2552 | std::vector<int64_t>& outputCounts, |
2553 | std::vector<int64_t>& inputCounts, |
2554 | const AllToAllOptions& /* unused */) { |
2555 | static auto invalidArgument = [](const std::string& msg) { |
2556 | TORCH_CHECK(false, "ProcessGroupGloo::alltoall_base: " + msg); |
2557 | }; |
2558 | |
2559 | TORCH_CHECK( |
2560 | outputTensor.device() == inputTensor.device(), |
      "output tensor and input tensor must be on the same type of device");
2562 | assertDense(invalidArgument, {outputTensor}); |
2563 | assertDense(invalidArgument, {inputTensor}); |
2564 | |
2565 | const auto& device = outputTensor.device(); |
2566 | c10::intrusive_ptr<AsyncAlltoallWork> work; |
2567 | auto tag = nextTag(); |
2568 | auto context = getContext(tag); |
2569 | |
2570 | if (device.type() == at::kCPU) { |
2571 | work = c10::make_intrusive<AsyncAlltoallWork>( |
2572 | std::move(context), |
2573 | outputTensor, |
2574 | inputTensor, |
2575 | outputCounts, |
2576 | inputCounts, |
2577 | tag); |
2578 | } else if (device.type() == at::kCUDA) { |
2579 | work = c10::make_intrusive<AsyncAlltoallCUDAWork>( |
2580 | std::move(context), |
2581 | outputTensor, |
2582 | inputTensor, |
2583 | outputCounts, |
2584 | inputCounts, |
2585 | tag); |
2586 | } else { |
    invalidArgument(c10::str("unsupported device type ", device.type()));
2588 | } |
2589 | enqueue(work); |
2590 | return work; |
2591 | } |
2592 | |
2593 | at::Tensor& checkSingleTensor(std::vector<at::Tensor>& tensors) { |
2594 | if (tensors.size() != 1) { |
    TORCH_CHECK(false, "ProcessGroupGloo::send takes a single tensor");
2596 | } |
2597 | auto& tensor = tensors[0]; |
2598 | if (!tensor.is_contiguous()) { |
    TORCH_CHECK(false, "input tensor has to be contiguous");
  }
  if (tensor.is_sparse()) {
    TORCH_CHECK(false, "input tensor has to be dense");
2603 | } |
2604 | return tensor; |
2605 | } |
2606 | |
2607 | uint32_t checkTag(int32_t tag) { |
  TORCH_CHECK(tag >= 0, "Tag must be nonnegative");
2609 | return (uint32_t)tag; |
2610 | } |
2611 | |
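// Point-to-point usage sketch (illustrative only; rank 0 sends a tensor `t`
// to rank 1, which receives into a same-shaped buffer):
//
//   std::vector<at::Tensor> buf = {t};
//   if (pg->getRank() == 0) {
//     pg->send(buf, /*dstRank=*/1, /*tag=*/0)->wait();
//   } else if (pg->getRank() == 1) {
//     pg->recv(buf, /*srcRank=*/0, /*tag=*/0)->wait();
//   }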
2612 | c10::intrusive_ptr<Work> ProcessGroupGloo::send( |
2613 | std::vector<at::Tensor>& tensors, |
2614 | int dstRank, |
2615 | int tag) { |
2616 | auto& tensor = checkSingleTensor(tensors); |
2617 | auto utag = checkTag(tag); |
2618 | auto ptr = tensor.data_ptr(); |
2619 | auto size = tensor.numel() * tensor.element_size(); |
2620 | |
2621 | // Construct unbound buffer. |
2622 | auto context = getContext(tag); |
2623 | auto buf = context->createUnboundBuffer(ptr, size); |
2624 | buf->send(dstRank, utag); |
2625 | |
2626 | // The work captures the tensor to prevent it being deallocated and |
2627 | // the unbound buffer to synchronize on completion of the send. |
2628 | return c10::make_intrusive<SendWork>(tensor, std::move(buf)); |
2629 | } |
2630 | |
2631 | c10::intrusive_ptr<Work> ProcessGroupGloo::recv( |
2632 | std::vector<at::Tensor>& tensors, |
2633 | int srcRank, |
2634 | int tag) { |
2635 | auto& tensor = checkSingleTensor(tensors); |
2636 | auto utag = checkTag(tag); |
2637 | auto ptr = tensor.data_ptr(); |
2638 | auto size = tensor.numel() * tensor.element_size(); |
2639 | |
2640 | // Construct unbound buffer. |
2641 | auto context = getContext(tag); |
2642 | auto buf = context->createUnboundBuffer(ptr, size); |
2643 | buf->recv(srcRank, utag); |
2644 | |
2645 | // The work captures the tensor to prevent it being deallocated and |
2646 | // the unbound buffer to synchronize on completion of the recv. |
  return c10::make_intrusive<RecvWork>(tensor, std::move(buf), "gloo:recv");
2648 | } |
2649 | |
2650 | c10::intrusive_ptr<Work> ProcessGroupGloo::recvAnysource( |
2651 | std::vector<at::Tensor>& tensors, |
2652 | int tag) { |
2653 | auto& tensor = checkSingleTensor(tensors); |
2654 | auto utag = checkTag(tag); |
2655 | auto ptr = tensor.data_ptr(); |
2656 | auto size = tensor.numel() * tensor.element_size(); |
2657 | |
2658 | // Construct unbound buffer. |
2659 | auto context = getContext(tag); |
2660 | auto buf = context->createUnboundBuffer(ptr, size); |
2661 | |
2662 | // Build list of ranks that this operation can recv from. In these |
2663 | // bindings we don't differentiate between ranks and can receive |
2664 | // from any other process in the group. |
  std::vector<int> srcRanks;
  srcRanks.reserve(size_);
2667 | for (const auto i : c10::irange(size_)) { |
2668 | srcRanks.push_back(i); |
2669 | } |
2670 | |
2671 | buf->recv(srcRanks, utag); |
2672 | |
2673 | // The work captures the tensor to prevent it being deallocated and |
2674 | // the unbound buffer to synchronize on completion of the recv. |
2675 | return c10::make_intrusive<RecvWork>( |
      tensor, std::move(buf), "gloo:recvAnySource");
2677 | } |
2678 | |
2679 | namespace { |
2680 | |
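// AsyncBarrierWork first waits on a snapshot of all previously enqueued work
// (held as weak references so finished work is not kept alive) and then runs
// a gloo barrier, so its completion implies that all prior collectives on
// this process group have finished.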
2681 | class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { |
2682 | public: |
2683 | AsyncBarrierWork( |
2684 | const std::shared_ptr<gloo::Context>& context, |
2685 | std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork, |
2686 | uint32_t tag) |
      : ProcessGroupGloo::AsyncWork({}, "gloo:barrier", c10::nullopt),
2688 | context(context), |
2689 | priorWork(std::move(priorWork)), |
2690 | tag(tag) {} |
2691 | |
2692 | std::shared_ptr<gloo::Context> context; |
2693 | std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork; |
2694 | const uint32_t tag; |
2695 | |
2696 | void run() override { |
2697 | // Wait on prior work to complete |
2698 | for (auto& weakWork : priorWork) { |
2699 | auto work = weakWork.lock(); |
2700 | if (work) { |
2701 | work->wait(); |
2702 | } |
2703 | } |
2704 | |
2705 | gloo::BarrierOptions opts(context); |
2706 | opts.setTag(tag); |
2707 | gloo::barrier(opts); |
2708 | } |
2709 | }; |
2710 | |
2711 | } // namespace |
2712 | |
2713 | c10::intrusive_ptr<Work> ProcessGroupGloo::barrier(const BarrierOptions& opts) { |
2714 | std::vector<c10::weak_intrusive_ptr<AsyncWork>> priorWork; |
2715 | |
2716 | // Snapshot all in progress and pending work as weak_ptr. |
2717 | // When executing a barrier, we need to ensure that all prior work |
2718 | // has completed before completing itself. |
2719 | { |
2720 | std::unique_lock<std::mutex> lock(workMutex_); |
2721 | priorWork.insert( |
2722 | priorWork.end(), workInProgress_.begin(), workInProgress_.end()); |
2723 | priorWork.insert(priorWork.end(), workQueue_.begin(), workQueue_.end()); |
2724 | } |
2725 | |
2726 | auto tag = nextTag(); |
2727 | auto context = getContext(tag); |
2728 | auto work = c10::make_intrusive<AsyncBarrierWork>( |
2729 | std::move(context), std::move(priorWork), tag); |
2730 | enqueue(work); |
2731 | return work; |
2732 | } |
2733 | |
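// monitoredBarrier protocol: every non-zero rank sends its rank to rank 0 and
// then blocks on an ack; rank 0 collects one recv per rank (reporting which
// ranks failed to respond within the timeout) and only then acks all ranks,
// so either every rank passes the barrier or rank 0 raises naming the
// unresponsive ranks.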
2734 | void ProcessGroupGloo::monitoredBarrier( |
2735 | const BarrierOptions& opts, |
2736 | bool waitAllRanks) { |
  C10_LOG_API_USAGE_ONCE("torch.distributed.monitored_barrier");
2738 | // Use default timeout if no timeout was specified. |
2739 | auto monitoredBarrierTimeout = |
2740 | (opts.timeout == kUnsetTimeout) ? this->options_->timeout : opts.timeout; |
2741 | auto rank = this->getRank(); |
2742 | auto t1 = nextTag(); |
2743 | auto t2 = nextTag(); |
2744 | std::vector<at::Tensor> commTensor = {at::tensor({rank})}; |
2745 | // only enforce timeout on rank 0. This is so that other ranks aren't timed |
2746 | // out first, bringing down the job without reporting which rank timed out. |
2747 | if (rank != 0) { |
2748 | auto sendWork = send(commTensor, 0, t1); |
2749 | auto recvWork = recv(commTensor, 0, t2); |
2750 | try { |
2751 | sendWork->wait(); |
2752 | recvWork->wait(); |
2753 | } catch (const std::exception& e) { |
2754 | const std::string error = c10::str( |
          "Rank ",
          rank,
          " successfully reached monitoredBarrier, but received errors while waiting",
          " for send/recv from rank 0. Please check rank 0 logs for faulty rank.");
      logAndThrow(
          error, c10::str(error, "\n Original exception: \n", e.what()));
2761 | } |
2762 | return; |
2763 | } |
2764 | auto startTime = std::chrono::steady_clock::now(); |
2765 | auto worldSize = this->getSize(); |
2766 | // Mappings of rank to recvWork/sendWork respectively. |
2767 | std::map<int, c10::intrusive_ptr<Work>> recvWorkMap; |
2768 | std::map<int, c10::intrusive_ptr<Work>> sendWorkMap; |
2769 | // Kick off recvWork and wait to unblock sendWork->wait() from non-zero ranks. |
2770 | // Failed/hanging ranks will not ack this call, letting rank 0 know about the |
2771 | // failure. |
2772 | for (const auto dstRank : c10::irange(1, worldSize)) { |
2773 | recvWorkMap.insert({dstRank, recv(commTensor, dstRank, t1)}); |
2774 | } |
2775 | |
2776 | auto waitLoop = [&](const std::map<int, c10::intrusive_ptr<Work>>& works) { |
2777 | std::vector<int> processedRanks; |
2778 | for (auto& work : works) { |
2779 | bool rankResponded = false; |
2780 | try { |
2781 | // Note: if waitAllRanks=false, we recompute the time remaining in |
2782 | // barrier and use this recomputed time in wait(). However, if |
2783 | // waitAllRanks=true, we use the original timeout, since if we use |
2784 | // up the entire timeout waiting for response from rank n, then we |
2785 | // won't have any timeout left to query ranks beginning with n + 1. |
2786 | auto remainingTime = |
2787 | getRemainingTime(startTime, monitoredBarrierTimeout, waitAllRanks); |
2788 | if (!waitAllRanks) { |
2789 | checkRemainingTime( |
2790 | monitoredBarrierTimeout, remainingTime, processedRanks, rank); |
2791 | } |
2792 | work.second->wait(remainingTime); |
2793 | rankResponded = true; |
2794 | } catch (const std::exception& e) { |
2795 | const std::string error = c10::str( |
            "[Rank 0]: Rank ",
            work.first,
            " failed to pass monitoredBarrier in ",
            monitoredBarrierTimeout.count(),
            " ms");
2801 | if (waitAllRanks) { |
2802 | LOG(ERROR) << error; |
2803 | } else { |
2804 | logAndThrow( |
              error, c10::str(error, "\n Original exception: \n", e.what()));
2806 | } |
2807 | } |
2808 | if (rankResponded) { |
2809 | processedRanks.push_back(work.first); |
2810 | } |
2811 | } |
2812 | // If we are collecting all failed ranks, check if we need to throw if |
2813 | // some ranks have not responded. |
2814 | // Ensure all ranks from 1, ... WORLD_SIZE -1 have been successfully |
2815 | // processed. |
    auto rankFailure =
        (processedRanks.size() != static_cast<size_t>(size_ - 1));
2817 | if (waitAllRanks && rankFailure) { |
2818 | std::vector<int> failedRanks; |
2819 | for (const auto i : c10::irange(1, size_)) { |
2820 | if (std::find(processedRanks.begin(), processedRanks.end(), i) == |
2821 | processedRanks.end()) { |
2822 | failedRanks.push_back(i); |
2823 | } |
2824 | } |
2825 | |
2826 | TORCH_INTERNAL_ASSERT(!failedRanks.empty()); |
      const std::string ranksStr = c10::Join(", ", failedRanks);
      const std::string error = c10::str(
          "[Rank 0]: Ranks ",
          ranksStr,
          " failed to pass monitoredBarrier in ",
          monitoredBarrierTimeout.count(),
          " ms");
2834 | logAndThrow(error, error); |
2835 | } |
2836 | }; |
2837 | |
2838 | waitLoop(recvWorkMap); |
2839 | // If we've reached here successfully, this means all ranks have acked in |
2840 | // monitoredBarrier. Unblock all ranks now by responding to their recv(). This |
2841 | // ensures that this is a true barrier in that all ranks exit it successfully |
2842 | // or none of them do. |
2843 | for (const auto dstRank : c10::irange(1, worldSize)) { |
2844 | sendWorkMap.insert({dstRank, send(commTensor, dstRank, t2)}); |
2845 | } |
2846 | |
2847 | waitLoop(sendWorkMap); |
2848 | } |
2849 | |
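// Sequence numbers are agreed on through the store: rank 0 seeds a random
// value and publishes it under kSeqNumStoreKey; all other ranks block until
// the key appears and then adopt the same value.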
2850 | void ProcessGroupGloo::setSequenceNumberForGroup() { |
2851 | if (rank_ == 0) { |
2852 | // Create and broadcast sequence number |
2853 | auto seq = 1 + rand(); |
2854 | sequenceNum_ = c10d::SequenceNum(seq); |
2855 | std::vector<char> values = c10d::toVec<char>(seq, kBytes); |
2856 | store_->set(kSeqNumStoreKey, values); |
2857 | } else { |
2858 | // Read rank 0's sequence number from store. |
2859 | sequenceNum_ = c10d::SequenceNum(); |
2860 | store_->wait({kSeqNumStoreKey}, options_->timeout); |
2861 | std::vector<char> values = store_->get(kSeqNumStoreKey); |
2862 | uint64_t num = c10d::fromVec<char>(values); |
2863 | sequenceNum_->set(num); |
2864 | } |
2865 | } |
2866 | |
2867 | uint64_t ProcessGroupGloo::getSequenceNumberForGroup() { |
2868 | if (sequenceNum_ == c10::nullopt) { |
2869 | return 0; |
2870 | } |
2871 | return sequenceNum_->get(); |
2872 | } |
2873 | |
2874 | } // namespace c10d |
2875 | |
2876 | #endif // USE_C10D_GLOO |
2877 | |