#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>

#ifdef USE_C10D_MPI

#include <iostream>
#include <limits>
#include <map>

#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>

#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
#endif

namespace c10d {

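// Wrap an MPI call and surface any failure (a return code other than
// MPI_SUCCESS) as a TORCH_CHECK error annotated with file, line, and the
// MPI error code.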
#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      TORCH_CHECK(false, err);                                           \
    }                                                                    \
  } while (0)

namespace {

// Op mapping
std::map<ReduceOp::RedOpType, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};
// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};

// Check for CUDA-aware MPI support at runtime. Currently, CUDA-aware MPI ops
// are only supported through Open MPI.
bool cudaAwareMpiCheck() {
// Run time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Checking the input tensor's validity
void checkSingleTensorHelper(const at::Tensor& tensor) {
  if (!tensor.is_contiguous()) {
    TORCH_CHECK(false, "input tensor has to be contiguous");
  }
  if (tensor.is_sparse()) {
    TORCH_CHECK(false, "input tensor has to be dense");
  }
  if (tensor.is_cuda() && !cudaAwareMpiCheck()) {
    TORCH_CHECK(
        false,
        "CUDA tensor detected and the MPI used doesn't "
        "have CUDA-aware MPI support");
  }
}

void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    TORCH_CHECK(
        false, "MPI process group does not support multi-GPU collectives");
  }
  checkSingleTensorHelper(tensors[0]);
}

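// Check that every tensor in `tensors` matches `t_in` in element count and
// scalar type, and that each one is itself a valid (contiguous, dense) buffer.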
void checkSameSizeAndType(
    const at::Tensor& t_in,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    if ((tensor.numel() != t_in.numel()) ||
        (tensor.scalar_type() != t_in.scalar_type())) {
      TORCH_CHECK(false, "Tensors are not equal in size or data type");
    }
    checkSingleTensorHelper(tensor);
  }
}

} // namespace

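// WorkMPI: completion handle for operations executed on the internal worker
// thread. The output tensors and the future are populated by the worker via
// finishWorkMPI() / finishWorkMPIError().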
std::vector<at::Tensor> ProcessGroupMPI::WorkMPI::result() {
  return outputTensors_;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupMPI::WorkMPI::getFuture() {
  return future_;
}

void ProcessGroupMPI::WorkMPI::finishWorkMPIError(std::exception_ptr eptr) {
  future_->setError(eptr);
  finish(eptr);
}

void ProcessGroupMPI::WorkMPI::finishWorkMPI() {
  future_->markCompleted(at::IValue(outputTensors_));
  finish();
}

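// AsyncWork: completion handle for non-blocking point-to-point operations.
// It wraps an MPI_Request directly (no worker-thread queue), so completion is
// observed through MPI_Test / MPI_Wait on that request.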
ProcessGroupMPI::AsyncWork::AsyncWork(
    MPI_Request request,
    std::vector<at::Tensor> outputTensors,
    const char* profilingTitle,
    const c10::optional<std::vector<at::Tensor>>& inputTensors)
    : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
      outputTensors_(std::move(outputTensors)),
      request_(request) {
  memset(&status_, 0, sizeof(status_));
}

ProcessGroupMPI::AsyncWork::~AsyncWork() {
  if (request_ != MPI_REQUEST_NULL) {
    std::cerr
        << "Attempted destruction of AsyncWork before work has completed, "
        << "terminating the program." << std::endl;
    std::terminate();
  }
}

bool ProcessGroupMPI::AsyncWork::isCompleted() {
  if (request_ == MPI_REQUEST_NULL) {
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  int flag = 0;
  MPI_CHECK(MPI_Test(&request_, &flag, &status_));
  if (request_ != MPI_REQUEST_NULL) {
    return false;
  }

  // request_ == MPI_REQUEST_NULL; the work has completed
  // Populate exception if request was not successful
  if (status_.MPI_ERROR != MPI_SUCCESS) {
    populateException();
  }

  return true;
}

bool ProcessGroupMPI::AsyncWork::isSuccess() const {
  if (request_ != MPI_REQUEST_NULL) {
    TORCH_CHECK(
        false,
        "Invalid call to AsyncWork::isSuccess before work has completed");
  }

  return status_.MPI_ERROR == MPI_SUCCESS;
}

int ProcessGroupMPI::AsyncWork::sourceRank() const {
  return status_.MPI_SOURCE;
}

bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
  if (request_ == MPI_REQUEST_NULL) {
    // AsyncWork needs to manually call profiling end callbacks if they are
    // set, since it does not call ProcessGroup::finish().
    if (Work::recordFunctionEndCallback_) {
      Work::recordFunctionEndCallback_();
      Work::recordFunctionEndCallback_ = nullptr;
    }
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Wait(&request_, &status_));
  auto ok = (status_.MPI_ERROR == MPI_SUCCESS);

  // AsyncWork needs to manually call profiling end callbacks if they are set,
  // since it does not call ProcessGroup::finish().
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }

  if (!ok) {
    populateException();
    std::rethrow_exception(exception_);
  }
  // Always return true, because abort API is not implemented.
  return true;
}

void ProcessGroupMPI::AsyncWork::abort() {
  TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.");
}

std::vector<at::Tensor> ProcessGroupMPI::AsyncWork::result() {
  return outputTensors_;
}

void ProcessGroupMPI::AsyncWork::populateException() {
  std::array<char, MPI_MAX_ERROR_STRING> buf;
  int len = buf.size();
  MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len));
  exception_ = std::make_exception_ptr(
      std::runtime_error(std::string(buf.data(), len)));
}

// Static global state
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
c10::once_flag ProcessGroupMPI::onceFlagInitMPI;

void ProcessGroupMPI::mpiExit() {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Finalize());
}

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  c10::call_once(onceFlagInitMPI, []() {
    MPI_CHECK(MPI_Init_thread(
        nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_));
    if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
      TORCH_CHECK(
          false,
          "The MPI implementation in use does not provide the "
          "minimum level of threading support required by the "
          "c10d package: MPI_THREAD_SERIALIZED");
    }
    if (std::atexit(ProcessGroupMPI::mpiExit)) {
      TORCH_CHECK(false, "Failed to register the MPI exit handler");
    }
  });
}

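// Create a ProcessGroupMPI over MPI_COMM_WORLD, or over a sub-communicator
// containing only `ranks` when a non-empty rank list is given. Processes that
// are not part of the resulting communicator receive a null pointer.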
c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
    std::vector<int> ranks) {
  // One-time MPI initialization
  initMPIOnce();

  MPI_Comm groupComm = MPI_COMM_WORLD;
  int rank = -1;
  int size = -1;

  {
    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);

    // If no ranks are specified, assume we're creating the root group
    if (!ranks.empty()) {
      MPI_Group worldGroup;
      MPI_Group ranksGroup;
      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
      MPI_CHECK(
          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
      // `MPI_Comm_create` can be flaky in certain cases.
      // See: https://github.com/pytorch/pytorch/issues/53899
      constexpr int kMaxNumRetries = 3;
      bool groupComm_updated = false;
      MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));
      for (const auto i : c10::irange(kMaxNumRetries)) {
        (void)i;
        // Retry until the communicator is successfully created.
        if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm) ==
            MPI_SUCCESS) {
          groupComm_updated = true;
          break;
        }
      }
      TORCH_CHECK(groupComm_updated, "MPI_Comm_create failed after retries");
      MPI_CHECK(MPI_Group_free(&worldGroup));
      MPI_CHECK(MPI_Group_free(&ranksGroup));
    }

    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
    if (groupComm != MPI_COMM_NULL) {
      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
      MPI_CHECK(MPI_Comm_size(groupComm, &size));

      if (rank < 0 || size < 0) {
        TORCH_CHECK(false, "Failed to get the world_size / rank");
      }
    }
  }

  // If this process is not part of the group, we don't construct a
  // process group instance. This is in line with the semantics of the
  // other process group types.
  if (groupComm == MPI_COMM_NULL) {
    return c10::intrusive_ptr<ProcessGroupMPI>();
  }

  return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm);
}

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
    : Backend(rank, size), stop_(false), pgComm_(pgComm) {
  if (pgComm_ == MPI_COMM_NULL) {
    TORCH_CHECK(false, "pgComm_ must not be MPI_COMM_NULL");
  }

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);

  init();
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

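// Wait for all queued work to be consumed, then stop and join the worker
// thread.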
void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);
  queueConsumeCV_.wait(lock, [&] { return queue_.empty(); });

  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  lock.unlock();
  queueProduceCV_.notify_all();

  // Join the single worker thread
  workerThread_.join();
}

void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(pgComm_, EXIT_FAILURE);
}

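// Worker loop: pops WorkEntry objects off the queue and runs them one at a
// time, so every queued MPI call is issued from this single thread. Results
// (or exceptions) are propagated to the corresponding WorkMPI objects.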
void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());

    queue_.pop_front();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();
    queueConsumeCV_.notify_one();

    try {
      workEntry->run(workEntry);
      work->finishWorkMPI();
    } catch (...) {
      work->finishWorkMPIError(std::current_exception());
    }

    lock.lock();
  }
}

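// Queue a WorkEntry for the worker thread and return a WorkMPI handle the
// caller can wait on.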
c10::intrusive_ptr<Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry,
    const char* profilingTitle,
    const c10::optional<std::vector<at::Tensor>>& inputTensors) {
  auto work =
      c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.push_back(std::make_tuple(std::move(entry), work));
  lock.unlock();
  queueProduceCV_.notify_one();
  return work;
}

c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:broadcast",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_reduce",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        auto dataPtr = (entry->src)[0].data_ptr();
        void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
        void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce(
            sendbuf,
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  checkSingleTensor(inputTensors);
  if (outputTensors.size() != 1) {
    TORCH_CHECK(
        false,
        "MPI process group only supports a single "
        "tensor op");
  }
  if (static_cast<size_t>(size_) != outputTensors[0].size()) {
    TORCH_CHECK(
        false,
        "Allgather: number of output tensors must equal "
        "the world size");
  }

  checkSameSizeAndType(inputTensors[0], outputTensors[0]);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        std::vector<at::Tensor> outputDataVec = entry->dst;
        auto flatOutputTensor = newLikeFlat(outputDataVec);

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allgather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            flatOutputTensor.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            pgComm_));

        for (const auto i : c10::irange(outputDataVec.size())) {
          outputDataVec[i].copy_(flatOutputTensor[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors[0], std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_gather",
      c10::optional<std::vector<at::Tensor>>(inputTensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& /* unused */,
    std::vector<at::Tensor>& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced");
}

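// gather: the root receives into a single flattened buffer and then copies
// the result back into the individual output tensors.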
c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  checkSingleTensor(inputTensors);

  if (rank_ != opts.rootRank) {
    if (outputTensors.size() > 0) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should be 0 "
          "for non-root");
    }
  } else {
    if (outputTensors.size() != 1) {
      TORCH_CHECK(false, "Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors must equal "
          "the world size");
    }
    checkSameSizeAndType(inputTensors[0], outputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        void* recvbuf = nullptr;
        at::Tensor flatOutputTensor;

        std::vector<at::Tensor> dstdata = entry->dst;
        if (rank_ == opts.rootRank) {
          flatOutputTensor = newLikeFlat(dstdata);
          recvbuf = flatOutputTensor.data_ptr();
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Gather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));

        if (rank_ == opts.rootRank) {
          const std::vector<at::Tensor>& outputDataVec = entry->dst;
          // Copy the flattened receive buffer back into the output tensors
          for (const auto i : c10::irange(outputDataVec.size())) {
            outputDataVec.at(i).copy_(flatOutputTensor[i]);
          }
        }
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors[0], std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        c10::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    auto entry =
        std::make_unique<WorkEntry>(&inputTensors, nullptr, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        c10::optional<std::vector<at::Tensor>>(inputTensors));
  }
}

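// scatter: the root copies its input tensors into a single flattened send
// buffer before calling MPI_Scatter.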
c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  checkSingleTensor(outputTensors);

  if (rank_ != opts.rootRank) {
    if (inputTensors.size() > 0) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should be 0 "
          "for non-root");
    }
  } else {
    if (inputTensors.size() != 1) {
      TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors must equal "
          "the world size");
    }
    checkSameSizeAndType(outputTensors[0], inputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        void* sendbuf = nullptr;
        at::Tensor flatInputTensor;

        if (rank_ == opts.rootRank) {
          std::vector<at::Tensor>& inputDataVec = entry->src;
          flatInputTensor = newLikeFlat(inputDataVec);
          sendbuf = flatInputTensor.data_ptr();

          // Copy the input tensors into the flattened send buffer
          for (const auto i : c10::irange(inputDataVec.size())) {
            flatInputTensor[i].copy_(inputDataVec.at(i));
          }
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Scatter(
            sendbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors[0], &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        inputTensors.size() > 0
            ? c10::optional<std::vector<at::Tensor>>(inputTensors[0])
            : c10::nullopt);
  } else {
    auto entry = std::make_unique<WorkEntry>(
        nullptr, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        inputTensors.size() > 0
            ? c10::optional<std::vector<at::Tensor>>(inputTensors[0])
            : c10::nullopt);
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);

  if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) {
    // We can use alltoall
    TORCH_CHECK(
        outputTensor.numel() == inputTensor.numel() &&
            outputTensor.scalar_type() == inputTensor.scalar_type(),
        "Tensors are not equal in size or data type");
    TORCH_CHECK(
        outputTensor.size(0) % size_ == 0,
        "Tensor's dim 0 does not divide equally across group size");

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this](std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoall(
              srcdata.data_ptr(),
              srcdata.numel() / size_,
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              dstdata.numel() / size_,
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        c10::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    // Need alltoallv
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this, inputSplitSizes, outputSplitSizes](
            std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          std::vector<int> send_lengths(size_);
          std::vector<int> recv_lengths(size_);
          std::vector<int> send_offsets(size_);
          std::vector<int> recv_offsets(size_);
          c10d::computeLengthsAndOffsets(
              inputSplitSizes, srcdata, &send_lengths, &send_offsets);
          c10d::computeLengthsAndOffsets(
              outputSplitSizes, dstdata, &recv_lengths, &recv_offsets);
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoallv(
              srcdata.data_ptr(),
              send_lengths.data(),
              send_offsets.data(),
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              recv_lengths.data(),
              recv_offsets.data(),
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        c10::optional<std::vector<at::Tensor>>(inputTensors));
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  TORCH_CHECK(
      inputTensors.size() == static_cast<size_t>(size_),
      "Number of input tensors is not equal to group size");
  TORCH_CHECK(
      outputTensors.size() == static_cast<size_t>(size_),
      "Number of output tensors is not equal to group size");
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::vector<int> send_lengths(size_);
        std::vector<int> recv_lengths(size_);
        std::vector<int> send_offsets(size_);
        std::vector<int> recv_offsets(size_);
        auto srcdata = entry->src;
        auto dstdata = entry->dst;
        int64_t src_len = c10d::computeLengthsAndOffsets(
            srcdata, &send_lengths, &send_offsets);
        int64_t dst_len = c10d::computeLengthsAndOffsets(
            dstdata, &recv_lengths, &recv_offsets);
        std::vector<int64_t> send_lengthsL(
            send_lengths.begin(), send_lengths.end());
        std::vector<int64_t> recv_lengthsL(
            recv_lengths.begin(), recv_lengths.end());
        at::Tensor srcFlatData = at::empty({src_len}, srcdata[0].options());
        at::Tensor dstFlatData = at::empty({dst_len}, dstdata[0].options());
        auto srcFlatDataSplits =
            srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          srcFlatDataSplits[i].copy_(srcdata[i].view({-1}));
        }
        c10::DeviceGuard guard1(srcdata[0].device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Alltoallv(
            srcFlatData.data_ptr(),
            send_lengths.data(),
            send_offsets.data(),
            mpiDatatype.at(srcdata[0].scalar_type()),
            dstFlatData.data_ptr(),
            recv_lengths.data(),
            recv_offsets.data(),
            mpiDatatype.at(dstdata[0].scalar_type()),
            pgComm_));

        auto dstFlatDataSplits =
            dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_to_all",
      c10::optional<std::vector<at::Tensor>>(inputTensors));
}

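// Point-to-point ops issue non-blocking MPI calls directly (under the global
// MPI lock) rather than going through the worker-thread queue, and return an
// AsyncWork wrapping the resulting MPI_Request.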
c10::intrusive_ptr<Work> ProcessGroupMPI::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Isend(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        dstRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      std::vector<at::Tensor>(),
      "mpi:send",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        srcRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recv",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        MPI_ANY_SOURCE,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recvAnySource",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Barrier(pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(nullptr, nullptr, std::move(runFunc));
  return enqueue(std::move(entry), "mpi:barrier", c10::nullopt);
}

c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
    at::Tensor& /* unused */,
    at::Tensor& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "no support for _allgather_base in MPI process group");
}

} // namespace c10d

#endif // USE_C10D_MPI