#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>

#ifdef USE_C10D_MPI

#include <iostream>
#include <limits>
#include <map>

#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>

#if defined(OPEN_MPI) && OPEN_MPI
#include <mpi-ext.h> // Needed for CUDA-aware check
#endif

namespace c10d {

#define MPI_CHECK(cmd)                                                   \
  do {                                                                   \
    int mpiStatus = cmd;                                                 \
    if (mpiStatus != MPI_SUCCESS) {                                      \
      std::string err = "MPI error in: " + std::string(__FILE__) + ":" + \
          std::to_string(__LINE__) +                                     \
          ", with error code: " + std::to_string(mpiStatus);             \
      TORCH_CHECK(false, err);                                           \
    }                                                                    \
  } while (0)
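
// Usage sketch: wrap any MPI call that returns an error code so that a
// non-MPI_SUCCESS result is turned into a c10d error, e.g.
//
//   MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD));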

namespace {

// Op mapping
std::map<ReduceOp::RedOpType, MPI_Op> mpiOp = {
    {ReduceOp::MIN, MPI_MIN},
    {ReduceOp::MAX, MPI_MAX},
    {ReduceOp::SUM, MPI_SUM},
    {ReduceOp::PRODUCT, MPI_PROD},
};
// Type mapping
std::map<at::ScalarType, MPI_Datatype> mpiDatatype = {
    {at::kByte, MPI_UNSIGNED_CHAR},
    {at::kChar, MPI_CHAR},
    {at::kDouble, MPI_DOUBLE},
    {at::kFloat, MPI_FLOAT},
    {at::kInt, MPI_INT},
    {at::kLong, MPI_LONG},
    {at::kShort, MPI_SHORT},
};
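
// Note: lookups below go through std::map::at(), so a dtype that is not listed
// here (for example at::kHalf or at::kBool) surfaces as a std::out_of_range
// exception rather than a descriptive c10d error message.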

// Check for CUDA-aware MPI support. Currently, CUDA-aware MPI ops are only
// supported through Open MPI.
bool cudaAwareMpiCheck() {
// Run time check
#if defined(MPIX_CUDA_AWARE_SUPPORT)
  if (MPIX_Query_cuda_support() == 1) {
    return true;
  } else {
    return false;
  }
#else // !defined(MPIX_CUDA_AWARE_SUPPORT)
  return false;
#endif // MPIX_CUDA_AWARE_SUPPORT
}

// Checking the input tensor's validity
void checkSingleTensorHelper(const at::Tensor& tensor) {
  if (!tensor.is_contiguous()) {
    TORCH_CHECK(false, "input tensor has to be contiguous");
  }
  if (tensor.is_sparse()) {
    TORCH_CHECK(false, "input tensor has to be dense");
  }
  if (tensor.is_cuda() && !cudaAwareMpiCheck()) {
    TORCH_CHECK(
        false,
        "CUDA tensor detected and the MPI used doesn't "
        "have CUDA-aware MPI support");
  }
}

void checkSingleTensor(const std::vector<at::Tensor>& tensors) {
  if (tensors.size() != 1) {
    TORCH_CHECK(
        false, "MPI process group does not support multi-GPU collectives");
  }
  checkSingleTensorHelper(tensors[0]);
}

void checkSameSizeAndType(
    const at::Tensor& t_in,
    const std::vector<at::Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    if ((tensor.numel() != t_in.numel()) ||
        (tensor.scalar_type() != t_in.scalar_type())) {
      TORCH_CHECK(false, "Tensors are not equal in size or data type");
    }
    checkSingleTensorHelper(tensor);
  }
}

} // namespace

std::vector<at::Tensor> ProcessGroupMPI::WorkMPI::result() {
  return outputTensors_;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupMPI::WorkMPI::getFuture() {
  return future_;
}
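
// A minimal sketch of how a caller might consume this future (the variable
// names are illustrative, and allreduce is just one example of a collective
// that returns a WorkMPI):
//
//   auto work = pg->allreduce(tensors);
//   auto fut = work->getFuture();
//   fut->wait();
//   auto result = fut->value().toTensorVector();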

void ProcessGroupMPI::WorkMPI::finishWorkMPIError(std::exception_ptr eptr) {
  future_->setError(eptr);
  finish(eptr);
}

void ProcessGroupMPI::WorkMPI::finishWorkMPI() {
  future_->markCompleted(at::IValue(outputTensors_));
  finish();
}

ProcessGroupMPI::AsyncWork::AsyncWork(
    MPI_Request request,
    std::vector<at::Tensor> outputTensors,
    const char* profilingTitle,
    const c10::optional<std::vector<at::Tensor>>& inputTensors)
    : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
      outputTensors_(std::move(outputTensors)),
      request_(request) {
  memset(&status_, 0, sizeof(status_));
}

ProcessGroupMPI::AsyncWork::~AsyncWork() {
  if (request_ != MPI_REQUEST_NULL) {
    std::cerr
        << "Attempted destruction of AsyncWork before work has completed, "
        << "terminating the program." << std::endl;
    std::terminate();
  }
}

bool ProcessGroupMPI::AsyncWork::isCompleted() {
  if (request_ == MPI_REQUEST_NULL) {
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  int flag = 0;
  MPI_CHECK(MPI_Test(&request_, &flag, &status_));
  if (request_ != MPI_REQUEST_NULL) {
    return false;
  }

  // request_ == MPI_REQUEST_NULL; the work has completed
  // Populate exception if request was not successful
  if (status_.MPI_ERROR != MPI_SUCCESS) {
    populateException();
  }

  return true;
}

bool ProcessGroupMPI::AsyncWork::isSuccess() const {
  if (request_ != MPI_REQUEST_NULL) {
    TORCH_CHECK(
        false,
        "Invalid call to AsyncWork::isSuccess before work has completed");
  }

  return status_.MPI_ERROR == MPI_SUCCESS;
}

int ProcessGroupMPI::AsyncWork::sourceRank() const {
  return status_.MPI_SOURCE;
}

bool ProcessGroupMPI::AsyncWork::wait(std::chrono::milliseconds /* unused */) {
  if (request_ == MPI_REQUEST_NULL) {
    // AsyncWork needs to manually call profiling end callbacks if they are
    // set, since it does not call ProcessGroup::finish().
    if (Work::recordFunctionEndCallback_) {
      Work::recordFunctionEndCallback_();
      Work::recordFunctionEndCallback_ = nullptr;
    }
    return true;
  }

  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Wait(&request_, &status_));
  auto ok = (status_.MPI_ERROR == MPI_SUCCESS);

  // AsyncWork needs to manually call profiling end callbacks if they are set,
  // since it does not call ProcessGroup::finish().
  if (Work::recordFunctionEndCallback_) {
    Work::recordFunctionEndCallback_();
    Work::recordFunctionEndCallback_ = nullptr;
  }

  if (!ok) {
    populateException();
    std::rethrow_exception(exception_);
  }
  // Always return true, because abort API is not implemented.
  return true;
}

void ProcessGroupMPI::AsyncWork::abort() {
  TORCH_CHECK(false, "ProcessGroupMPI::AsyncWork::abort not implemented.");
}

std::vector<at::Tensor> ProcessGroupMPI::AsyncWork::result() {
  return outputTensors_;
}

void ProcessGroupMPI::AsyncWork::populateException() {
  std::array<char, MPI_MAX_ERROR_STRING> buf;
  int len = buf.size();
  MPI_CHECK(MPI_Error_string(status_.MPI_ERROR, buf.data(), &len));
  exception_ = std::make_exception_ptr(
      std::runtime_error(std::string(buf.data(), len)));
}
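
// AsyncWork wraps a raw MPI_Request produced by a nonblocking point-to-point
// call (see send/recv/recvAnysource below). A minimal sketch of how a caller
// might overlap computation with the pending request (illustrative names):
//
//   auto work = pg->send(tensors, /*dstRank=*/1, /*tag=*/0);
//   while (!work->isCompleted()) {
//     // do other work here
//   }
//   work->wait(); // rethrows if the underlying MPI request failed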

// Static global states
int ProcessGroupMPI::mpiThreadSupport_ = 0;
std::mutex ProcessGroupMPI::pgGlobalMutex_;
// We only want to initialize once
c10::once_flag ProcessGroupMPI::onceFlagInitMPI;

void ProcessGroupMPI::mpiExit() {
  std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
  MPI_CHECK(MPI_Finalize());
}

void ProcessGroupMPI::initMPIOnce() {
  // Initialize MPI environment
  c10::call_once(onceFlagInitMPI, []() {
    MPI_CHECK(MPI_Init_thread(
        nullptr, nullptr, MPI_THREAD_SERIALIZED, &mpiThreadSupport_));
    if (mpiThreadSupport_ < MPI_THREAD_SERIALIZED) {
      TORCH_CHECK(
          false,
          "Used MPI implementation doesn't have the "
          "minimum level of threading support: "
          "MPI_THREAD_SERIALIZED. This is required by "
          "c10d package");
    }
    if (std::atexit(ProcessGroupMPI::mpiExit)) {
      TORCH_CHECK(false, "Failed to register the MPI exit handler");
    }
  });
}

c10::intrusive_ptr<ProcessGroupMPI> ProcessGroupMPI::createProcessGroupMPI(
    std::vector<int> ranks) {
  // Once initialization
  initMPIOnce();

  MPI_Comm groupComm = MPI_COMM_WORLD;
  int rank = -1;
  int size = -1;

  {
    std::lock_guard<std::mutex> globalLock(pgGlobalMutex_);

    // If no ranks are specified, assume we're creating the root group
    if (!ranks.empty()) {
      MPI_Group worldGroup;
      MPI_Group ranksGroup;
      MPI_CHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
      MPI_CHECK(
          MPI_Group_incl(worldGroup, ranks.size(), ranks.data(), &ranksGroup));
      // `MPI_Comm_create` can be flaky in certain cases.
      // See: https://github.com/pytorch/pytorch/issues/53899
      constexpr int kMaxNumRetries = 3;
      bool groupComm_updated = false;
      MPI_Barrier(MPI_COMM_WORLD);
      for (const auto i : c10::irange(kMaxNumRetries)) {
        (void)i;
        // MPI_Comm_create returns MPI_SUCCESS (0) on success, so retry until
        // it succeeds or we run out of attempts.
        if (MPI_Comm_create(MPI_COMM_WORLD, ranksGroup, &groupComm) ==
            MPI_SUCCESS) {
          groupComm_updated = true;
          break;
        }
      }
      TORCH_CHECK(
          groupComm_updated,
          "MPI_Comm_create failed after ",
          kMaxNumRetries,
          " retries");
      MPI_CHECK(MPI_Group_free(&worldGroup));
      MPI_CHECK(MPI_Group_free(&ranksGroup));
    }

    // Fetch rank and world size for this group (MPI_COMM_WORLD or new)
    if (groupComm != MPI_COMM_NULL) {
      MPI_CHECK(MPI_Comm_rank(groupComm, &rank));
      MPI_CHECK(MPI_Comm_size(groupComm, &size));

      if (rank < 0 || size < 0) {
        TORCH_CHECK(false, "Failed to get the world_size / rank");
      }
    }
  }

  // If this process is not part of the group, we don't construct a
  // process group instance. This is in line with the semantics of the
  // other process group types.
  if (groupComm == MPI_COMM_NULL) {
    return c10::intrusive_ptr<ProcessGroupMPI>();
  }

  return c10::make_intrusive<ProcessGroupMPI>(rank, size, groupComm);
}
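
// A minimal usage sketch, assuming the process was launched under an MPI
// launcher such as mpirun (the variable names are illustrative):
//
//   auto pg = ProcessGroupMPI::createProcessGroupMPI({}); // empty ranks => world group
//   if (pg) {
//     std::vector<at::Tensor> ts = {at::ones({8})};
//     pg->allreduce(ts)->wait(); // ts[0] now holds the elementwise sum over ranks
//   }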

ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pgComm)
    : Backend(rank, size), stop_(false), pgComm_(pgComm) {
  if (pgComm_ == MPI_COMM_NULL) {
    TORCH_CHECK(false, "pgComm_ must not be MPI_COMM_NULL");
  }

  // Start the worker thread accepting MPI calls
  workerThread_ = std::thread(&ProcessGroupMPI::runLoop, this);

  init();
}

ProcessGroupMPI::~ProcessGroupMPI() {
  destroy();
}

void ProcessGroupMPI::destroy() {
  std::unique_lock<std::mutex> lock(pgMutex_);
  queueConsumeCV_.wait(lock, [&] { return queue_.empty(); });

  // Queue is empty, signal stop
  stop_ = true;

  // Release lock to allow threads to terminate
  lock.unlock();
  queueProduceCV_.notify_all();

  // Join the single worker thread
  workerThread_.join();
}

void ProcessGroupMPI::abort() {
  destroy();
  MPI_Abort(pgComm_, EXIT_FAILURE);
}

void ProcessGroupMPI::runLoop() {
  std::unique_lock<std::mutex> lock(pgMutex_);

  while (!stop_) {
    if (queue_.empty()) {
      queueProduceCV_.wait(lock);
      continue;
    }

    auto workTuple = std::move(queue_.front());

    queue_.pop_front();

    auto& workEntry = std::get<0>(workTuple);
    auto& work = std::get<1>(workTuple);

    lock.unlock();
    queueConsumeCV_.notify_one();

    try {
      workEntry->run(workEntry);
      work->finishWorkMPI();
    } catch (...) {
      work->finishWorkMPIError(std::current_exception());
    }

    lock.lock();
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::enqueue(
    std::unique_ptr<WorkEntry> entry,
    const char* profilingTitle,
    const c10::optional<std::vector<at::Tensor>>& inputTensors) {
  auto work =
      c10::make_intrusive<WorkMPI>(entry->dst, profilingTitle, inputTensors);
  std::unique_lock<std::mutex> lock(pgMutex_);
  queue_.push_back(std::make_tuple(std::move(entry), work));
  lock.unlock();
  queueProduceCV_.notify_one();
  return work;
}

c10::intrusive_ptr<Work> ProcessGroupMPI::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  checkSingleTensor(tensors);
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Bcast(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:broadcast",
      c10::optional<std::vector<at::Tensor>>(tensors));
}
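
// A usage sketch (the root rank defaults to 0 through BroadcastOptions; the
// variable names are illustrative):
//
//   std::vector<at::Tensor> ts = {at::zeros({4})};
//   if (pg->getRank() == 0) {
//     ts[0].fill_(42);
//   }
//   pg->broadcast(ts)->wait(); // every rank's ts[0] now contains 42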

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allreduce(
            MPI_IN_PLACE,
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_reduce",
      c10::optional<std::vector<at::Tensor>>(tensors));
}
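
// Note that the reduction happens in place (MPI_IN_PLACE): the caller's tensor
// is both the send and the receive buffer. A sketch with a non-default reduce
// op (illustrative names):
//
//   std::vector<at::Tensor> ts = {at::full({4}, pg->getRank(), at::kFloat)};
//   AllreduceOptions opts;
//   opts.reduceOp = ReduceOp::MAX;
//   pg->allreduce(ts, opts)->wait(); // ts[0] == world_size - 1 everywhere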

c10::intrusive_ptr<Work> ProcessGroupMPI::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  TORCH_CHECK(false, "allreduce_coalesced is currently not supported with MPI");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  checkSingleTensor(tensors);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        auto dataPtr = (entry->src)[0].data_ptr();
        void* sendbuf = (rank_ == opts.rootRank) ? MPI_IN_PLACE : dataPtr;
        void* recvbuf = (rank_ == opts.rootRank) ? dataPtr : nullptr;

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Reduce(
            sendbuf,
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            mpiOp.at(opts.reduceOp),
            opts.rootRank,
            pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:reduce",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  checkSingleTensor(inputTensors);
  if (outputTensors.size() != 1) {
    TORCH_CHECK(
        false,
        "MPI process group only supports a single "
        "tensor op");
  }
  if (static_cast<size_t>(size_) != outputTensors[0].size()) {
    TORCH_CHECK(
        false,
        "All gather: number of output tensors should equal "
        "the world size");
  }

  checkSameSizeAndType(inputTensors[0], outputTensors[0]);

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        std::vector<at::Tensor> outputDataVec = entry->dst;
        auto flatOutputTensor = newLikeFlat(outputDataVec);

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Allgather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            flatOutputTensor.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            pgComm_));

        for (const auto i : c10::irange(outputDataVec.size())) {
          outputDataVec[i].copy_(flatOutputTensor[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors[0], std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_gather",
      c10::optional<std::vector<at::Tensor>>(inputTensors));
}
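
// A usage sketch: each rank contributes one tensor and receives one tensor per
// rank in outputTensors[0] (illustrative names):
//
//   std::vector<at::Tensor> in = {at::ones({2}) * pg->getRank()};
//   std::vector<std::vector<at::Tensor>> out(1);
//   for (int i = 0; i < pg->getSize(); ++i) {
//     out[0].push_back(at::empty({2}));
//   }
//   pg->allgather(out, in)->wait(); // out[0][r] now holds rank r's tensor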

c10::intrusive_ptr<Work> ProcessGroupMPI::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& /* unused */,
    std::vector<at::Tensor>& /* unused */,
    const AllgatherOptions& /* unused */) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support allgather_coalesced");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  checkSingleTensor(inputTensors);

  if (rank_ != opts.rootRank) {
    if (outputTensors.size() > 0) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should be 0 "
          "for non-root");
    }
  } else {
    if (outputTensors.size() != 1) {
      TORCH_CHECK(false, "Gather: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != outputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Gather: number of output tensors should equal "
          "the world size");
    }
    checkSameSizeAndType(inputTensors[0], outputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->src)[0];
        void* recvbuf = nullptr;
        at::Tensor flatOutputTensor;

        std::vector<at::Tensor> dstdata = entry->dst;
        if (rank_ == opts.rootRank) {
          flatOutputTensor = newLikeFlat(dstdata);
          recvbuf = flatOutputTensor.data_ptr();
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Gather(
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            recvbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));

        if (rank_ == opts.rootRank) {
          const std::vector<at::Tensor>& outputDataVec = entry->dst;
          // copy the flattened output tensors to the outputs
          for (const auto i : c10::irange(outputDataVec.size())) {
            outputDataVec.at(i).copy_(flatOutputTensor[i]);
          }
        }
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors[0], std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        c10::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    auto entry =
        std::make_unique<WorkEntry>(&inputTensors, nullptr, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:gather",
        c10::optional<std::vector<at::Tensor>>(inputTensors));
  }
}

c10::intrusive_ptr<Work> ProcessGroupMPI::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  checkSingleTensor(outputTensors);

  if (rank_ != opts.rootRank) {
    if (inputTensors.size() > 0) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should be 0 "
          "for non-root");
    }
  } else {
    if (inputTensors.size() != 1) {
      TORCH_CHECK(false, "Scatter: multi-GPU collective is not supported");
    }
    if (static_cast<size_t>(size_) != inputTensors[0].size()) {
      TORCH_CHECK(
          false,
          "Scatter: number of input tensors should equal "
          "the world size");
    }
    checkSameSizeAndType(outputTensors[0], inputTensors[0]);
  }

  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [opts, this](std::unique_ptr<WorkEntry>& entry) {
        auto data = (entry->dst)[0];
        void* sendbuf = nullptr;
        at::Tensor flatInputTensor;

        if (rank_ == opts.rootRank) {
          std::vector<at::Tensor>& inputDataVec = entry->src;
          flatInputTensor = newLikeFlat(inputDataVec);
          sendbuf = flatInputTensor.data_ptr();

          // copy the input tensors to the flattened send buffer
          for (const auto i : c10::irange(inputDataVec.size())) {
            flatInputTensor[i].copy_(inputDataVec.at(i));
          }
        }

        c10::DeviceGuard guard(data.device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Scatter(
            sendbuf,
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            data.data_ptr(),
            data.numel(),
            mpiDatatype.at(data.scalar_type()),
            opts.rootRank,
            pgComm_));
      };

  if (rank_ == opts.rootRank) {
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors[0], &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        inputTensors.size() > 0
            ? c10::optional<std::vector<at::Tensor>>(inputTensors[0])
            : c10::nullopt);
  } else {
    auto entry = std::make_unique<WorkEntry>(
        nullptr, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:scatter",
        inputTensors.size() > 0
            ? c10::optional<std::vector<at::Tensor>>(inputTensors[0])
            : c10::nullopt);
  }
}
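
// A usage sketch: only the root supplies input tensors (one per rank), and
// every rank receives its slice in outputTensors[0] (illustrative names,
// assuming a world size of 2):
//
//   std::vector<at::Tensor> out = {at::empty({2})};
//   std::vector<std::vector<at::Tensor>> in; // stays empty on non-root ranks
//   if (pg->getRank() == 0) {
//     in = {{at::zeros({2}), at::ones({2})}};
//   }
//   pg->scatter(out, in)->wait(); // rank 1's out[0] now holds ones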

c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
}

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
    at::Tensor& outputTensor,
    at::Tensor& inputTensor,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  checkSingleTensorHelper(inputTensor);
  checkSingleTensorHelper(outputTensor);

  if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) {
    // We can use alltoall
    TORCH_CHECK(
        outputTensor.numel() == inputTensor.numel() &&
            outputTensor.type() == inputTensor.type(),
        "Tensors are not equal in size or data type");
    TORCH_CHECK(
        outputTensor.size(0) % size_ == 0,
        "Tensor's dim 0 does not divide equally across group size");

    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this](std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoall(
              srcdata.data_ptr(),
              srcdata.numel() / size_,
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              dstdata.numel() / size_,
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        c10::optional<std::vector<at::Tensor>>(inputTensors));
  } else {
    // Need alltoallv
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
    std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
        [this, inputSplitSizes, outputSplitSizes](
            std::unique_ptr<WorkEntry>& entry) {
          auto srcdata = (entry->src)[0];
          auto dstdata = (entry->dst)[0];
          std::vector<int> send_lengths(size_);
          std::vector<int> recv_lengths(size_);
          std::vector<int> send_offsets(size_);
          std::vector<int> recv_offsets(size_);
          c10d::computeLengthsAndOffsets(
              inputSplitSizes, srcdata, &send_lengths, &send_offsets);
          c10d::computeLengthsAndOffsets(
              outputSplitSizes, dstdata, &recv_lengths, &recv_offsets);
          c10::DeviceGuard guard(srcdata.device());
          std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
          MPI_CHECK(MPI_Alltoallv(
              srcdata.data_ptr(),
              send_lengths.data(),
              send_offsets.data(),
              mpiDatatype.at(srcdata.scalar_type()),
              dstdata.data_ptr(),
              recv_lengths.data(),
              recv_offsets.data(),
              mpiDatatype.at(dstdata.scalar_type()),
              pgComm_));
        };
    std::vector<at::Tensor> inputTensors = {inputTensor};
    std::vector<at::Tensor> outputTensors = {outputTensor};
    auto entry = std::make_unique<WorkEntry>(
        &inputTensors, &outputTensors, std::move(runFunc));
    return enqueue(
        std::move(entry),
        "mpi:all_to_all",
        c10::optional<std::vector<at::Tensor>>(inputTensors));
  }
}
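
// A usage sketch for the even-split case (no split sizes): dim 0 of both
// tensors is divided evenly across the group (illustrative names):
//
//   const auto W = pg->getSize();
//   at::Tensor in = at::arange(W * 4, at::kFloat); // W blocks of 4 elements
//   at::Tensor out = at::empty_like(in);
//   std::vector<int64_t> noSplit;
//   pg->alltoall_base(out, in, noSplit, noSplit)->wait();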

c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  TORCH_CHECK(
      inputTensors.size() == static_cast<size_t>(size_),
      "Number of input tensors is not equal to group size");
  TORCH_CHECK(
      outputTensors.size() == static_cast<size_t>(size_),
      "Number of output tensors is not equal to group size");
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::vector<int> send_lengths(size_);
        std::vector<int> recv_lengths(size_);
        std::vector<int> send_offsets(size_);
        std::vector<int> recv_offsets(size_);
        auto srcdata = entry->src;
        auto dstdata = entry->dst;
        int64_t src_len = c10d::computeLengthsAndOffsets(
            srcdata, &send_lengths, &send_offsets);
        int64_t dst_len = c10d::computeLengthsAndOffsets(
            dstdata, &recv_lengths, &recv_offsets);
        std::vector<int64_t> send_lengthsL(
            send_lengths.begin(), send_lengths.end());
        std::vector<int64_t> recv_lengthsL(
            recv_lengths.begin(), recv_lengths.end());
        at::Tensor srcFlatData = at::empty({src_len}, srcdata[0].options());
        at::Tensor dstFlatData = at::empty({dst_len}, dstdata[0].options());
        auto srcFlatDataSplits =
            srcFlatData.split_with_sizes(c10::IntArrayRef(send_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          srcFlatDataSplits[i].copy_(srcdata[i].view({-1}));
        }
        c10::DeviceGuard guard1(srcdata[0].device());
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Alltoallv(
            srcFlatData.data_ptr(),
            send_lengths.data(),
            send_offsets.data(),
            mpiDatatype.at(srcdata[0].scalar_type()),
            dstFlatData.data_ptr(),
            recv_lengths.data(),
            recv_offsets.data(),
            mpiDatatype.at(dstdata[0].scalar_type()),
            pgComm_));

        auto dstFlatDataSplits =
            dstFlatData.split_with_sizes(c10::IntArrayRef(recv_lengthsL), 0);
        for (const auto i : c10::irange(size_)) {
          dstdata[i].view({-1}).copy_(dstFlatDataSplits[i]);
        }
      };
  auto entry = std::make_unique<WorkEntry>(
      &inputTensors, &outputTensors, std::move(runFunc));
  return enqueue(
      std::move(entry),
      "mpi:all_to_all",
      c10::optional<std::vector<at::Tensor>>(inputTensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Isend(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        dstRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      std::vector<at::Tensor>(),
      "mpi:send",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        srcRank,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recv",
      c10::optional<std::vector<at::Tensor>>(tensors));
}
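
// A point-to-point usage sketch: these return an AsyncWork backed by a
// nonblocking MPI request, so the transfer is only guaranteed complete after
// wait() (illustrative names, assuming at least two ranks):
//
//   std::vector<at::Tensor> msg = {at::ones({4})};
//   std::vector<at::Tensor> buf = {at::zeros({4})};
//   if (pg->getRank() == 0) {
//     pg->send(msg, /*dstRank=*/1, /*tag=*/0)->wait();
//   } else if (pg->getRank() == 1) {
//     pg->recv(buf, /*srcRank=*/0, /*tag=*/0)->wait(); // buf[0] now holds ones
//   }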

c10::intrusive_ptr<Work> ProcessGroupMPI::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  checkSingleTensor(tensors);

  auto& tensor = tensors[0];
  MPI_Request request = MPI_REQUEST_NULL;

  {
    c10::DeviceGuard guard(tensor.device());
    std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
    MPI_CHECK(MPI_Irecv(
        tensor.data_ptr(),
        tensor.numel(),
        mpiDatatype.at(tensor.scalar_type()),
        MPI_ANY_SOURCE,
        tag,
        pgComm_,
        &request));
  }

  return c10::make_intrusive<AsyncWork>(
      request,
      tensors,
      "mpi:recvAnySource",
      c10::optional<std::vector<at::Tensor>>(tensors));
}

c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
      [this](std::unique_ptr<WorkEntry>& entry) {
        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
        MPI_CHECK(MPI_Barrier(pgComm_));
      };
  auto entry =
      std::make_unique<WorkEntry>(nullptr, nullptr, std::move(runFunc));
  return enqueue(std::move(entry), "mpi:barrier", c10::nullopt);
}

c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
    at::Tensor& /*unused */,
    at::Tensor& /*unused */,
    const AllgatherOptions& /*unused */) {
  TORCH_CHECK(false, "no support for _allgather_base in MPI process group");
}

} // namespace c10d

#endif // USE_C10D_MPI