#pragma once

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/macros/Macros.h>

#include <torch/csrc/distributed/c10d/Work.hpp>
// *************************************************************************
// PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN
// versions 1.7 and 1.8.
// PLEASE DO NOT ADD ANY DEPENDENCIES.
// SEE RFC: https://github.com/pytorch/pytorch/issues/39662
// *************************************************************************

constexpr auto kProcessGroupDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

// ProcessGroup is a base class that captures collective and point-to-point
// communication in a fixed set of processes.
//
// The functions specified in the class below describe the API alone;
// implementations are provided in subclasses.
//
// Every function that performs I/O is executed asynchronously by a
// thread pool owned by the ProcessGroup (by default). Each returns an
// object that can be used to wait for completion or error.
//
// The ProcessGroup can instantiate subgroups with fewer or an equal
// number of members. Implementations must take care that multiple
// process groups can be used in parallel and synchronize accordingly.
//
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. For members of the
// process group to find each other (referred to as rendezvous from
// hereon), a shared Store is passed in at construction time.
//
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 public:
  // ProcessGroup Options is a base struct that defines the basic options
  // when constructing a ProcessGroup. Each ProcessGroup subclass should
  // extend this struct and define its own options if it wants to provide
  // more config options (beyond the basic ones defined here) to the end
  // user.
  struct TORCH_API Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
        : timeout(timeout), backend(std::move(backend)) {}
    ~Options() override = default;

    std::chrono::milliseconds timeout;

    // backend name
    const std::string backend;
  };
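
  // Illustrative sketch only: a hypothetical backend-specific Options
  // subclass (the names below are made up for illustration and are not part
  // of this API) would add its own knobs on top of the base fields:
  //
  //   struct MyBackendOptions : ProcessGroup::Options {
  //     MyBackendOptions() : Options(/*backend=*/"mybackend") {}
  //     bool isHighPriorityStream = false; // extra, backend-specific knob
  //   };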

  enum BackendType {
    UNDEFINED = 0,
    GLOO = 1,
    NCCL = 2,
    UCC = 3,
    MPI = 4,
    CUSTOM = 5,
  };

  // Not used; kept for backwards compatibility and only used for the TypeDef
  // in Ops.cpp.
  explicit ProcessGroup(int rank, int size);

  explicit ProcessGroup(
      const c10::intrusive_ptr<::c10d::Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options);
  ~ProcessGroup() override;

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

  virtual const std::string getBackendName() const {
    return options_->backend;
  };

  BackendType getBackendType() const {
    return backendType_;
  };

  virtual void startCoalescing(c10::DeviceType deviceType) {
    // only NCCL has implemented startCoalescing, so only execute for NCCL
    // backends
    if (getBackendType() == BackendType::NCCL) {
      getBackend(deviceType)->startCoalescing();
    }
  }

  virtual void endCoalescing(
      c10::DeviceType deviceType,
      std::vector<c10::intrusive_ptr<Work>>& reqs) {
    // only NCCL has implemented endCoalescing, so only execute for NCCL
    // backends
    if (getBackendType() == BackendType::NCCL) {
      getBackend(deviceType)->endCoalescing(reqs);
    }
  }
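
  // Illustrative sketch of the intended bracketing pattern only (assumes an
  // NCCL-backed group bound to `pg`, CUDA tensor vectors `tensorsA` and
  // `tensorsB`, and a local `works` vector that is expected to be populated
  // by endCoalescing with the coalesced work handles):
  //
  //   pg->startCoalescing(c10::DeviceType::CUDA);
  //   pg->allreduce(tensorsA);
  //   pg->allreduce(tensorsB);
  //   std::vector<c10::intrusive_ptr<c10d::Work>> works;
  //   pg->endCoalescing(c10::DeviceType::CUDA, works);
  //   for (auto& w : works) {
  //     w->wait(); // the coalesced collectives complete together
  //   }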

  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::broadcast_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    at::TensorList,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t,
                    int64_t,
                    int64_t)>();
    // It's awkward to unbox the opts here and box them again in the custom
    // C++ op. But it's also complicated to make opts a CustomClassHolder.
    // Leave it as it is for now.
    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.rootTensor,
        opts.timeout.count()));
  }
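
  // All collectives below follow the same pattern as broadcast(): look up the
  // registered "c10d::*" operator from the dispatcher once (cached in a
  // function-local static), reclaim `this` into an intrusive_ptr, and pass
  // the unpacked fields of opts as scalar arguments. Illustrative caller-side
  // sketch (assumes `pg` and a one-element tensor vector `tensors`):
  //
  //   c10d::BroadcastOptions opts;
  //   opts.rootRank = 0; // rank whose tensor is replicated to all others
  //   pg->broadcast(tensors, opts)->wait();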

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allreduce_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    at::TensorList,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    const c10::intrusive_ptr<::c10d::ReduceOp>&,
                    int64_t)>();

    return std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::allreduce_coalesced_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const c10::intrusive_ptr<::c10d::ReduceOp>&,
                             int64_t)>();

    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::reduce_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const c10::intrusive_ptr<::c10d::ReduceOp>&,
                             int64_t,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.rootRank,
        opts.rootTensor,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::allgather_", "")
                         .typed<std::tuple<
                             std::vector<std::vector<at::Tensor>>,
                             c10::intrusive_ptr<Work>>(
                             const std::vector<std::vector<at::Tensor>>&,
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t)>();

    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer
  // that is interpreted as a contiguous collection whose size is the size of
  // inputBuffer times WORLD_SIZE.
  // For implementers of the ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
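  // Illustrative shape sketch (not a normative spec): with WORLD_SIZE = 4 and
  // an inputBuffer of shape {8}, outputBuffer is expected to be preallocated
  // with shape {32}, where the slice [rank * 8, (rank + 1) * 8) holds the
  // contribution gathered from `rank`:
  //
  //   at::Tensor in = at::ones({8});
  //   at::Tensor out = at::empty({4 * 8});
  //   pg->_allgather_base(out, in)->wait();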
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::_allgather_base_", "")
            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this)));
  }

  // This function is deprecated and will be moved out of ProcessGroup to
  // comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allgather_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const std::vector<std::vector<at::Tensor>>&,
                const at::TensorList&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();

    return op.call(
        outputTensorLists,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this));
  }

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::gather_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             const std::vector<std::vector<at::Tensor>>&,
                             const at::TensorList&,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::scatter_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const std::vector<std::vector<at::Tensor>>&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::reduce_scatter_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const std::vector<std::vector<at::Tensor>>&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    const c10::intrusive_ptr<::c10d::ReduceOp>&,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::_reduce_scatter_base_", "")
            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                int64_t)>();
    return std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.timeout.count()));
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::alltoall_base_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::Tensor&,
                             at::Tensor&,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             std::vector<int64_t>,
                             std::vector<int64_t>,
                             int64_t)>();
    return op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        outputSplitSizes,
        inputSplitSizes,
        opts.timeout.count());
  }

  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::alltoall_", "")
            .typed<
                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
                    const at::TensorList&,
                    const at::TensorList&,
                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                    int64_t)>();
    return std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.timeout.count()));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& opts,
      bool wait_all_ranks = false) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::monitored_barrier_", "")
                         .typed<void(
                             at::Tensor,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const std::vector<int64_t>&,
                             int64_t,
                             bool)>();
    // Default to using the CPU implementation; monitored barrier is only
    // supported by the GLOO backend.
    at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU));
    op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count(),
        wait_all_ranks);
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to the other ranks using the store. Only
  // implemented for the GLOO, NCCL, and UCC backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendType = getBackendType();
    // TODO: HACK for backend name to get sequence number for that backend.
    if (backendType == ProcessGroup::BackendType::GLOO ||
        backendType == ProcessGroup::BackendType::NCCL ||
        backendType == ProcessGroup::BackendType::UCC) {
      getDefaultBackend()->setSequenceNumberForGroup();
    } else {
      TORCH_CHECK(
          false,
          c10::str(
              "ProcessGroup ",
              getBackendName(),
              " does not yet support sequence numbers."));
    }
  }

  // Retrieves the current sequence number for the whole group, which should
  // be in sync across ranks. If the returned number is not consistent across
  // the group, it may indicate some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendType = getBackendType();

    // TODO: HACK for backend name to get sequence number for that backend.
    if (backendType == ProcessGroup::BackendType::GLOO ||
        backendType == ProcessGroup::BackendType::NCCL ||
        backendType == ProcessGroup::BackendType::UCC) {
      return getDefaultBackend()->getSequenceNumberForGroup();
    } else {
      TORCH_CHECK(
          false,
          c10::str(
              "ProcessGroup ",
              getBackendName(),
              " does not yet support sequence numbers."));
    }
  }
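
  // Illustrative debugging sketch (assumes the returned values are compared
  // across ranks out of band, e.g. by logging them):
  //
  //   pg->setSequenceNumberForGroup();   // once, right after setup
  //   ...                                // run collectives
  //   uint64_t seq = pg->getSequenceNumberForGroup();
  //   // A rank whose `seq` differs from its peers has likely skipped or
  //   // repeated a collective.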

  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::send", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        dstRank,
        tag);
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::recv_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        srcRank,
        tag);
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) {
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::recv_any_source_", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::TensorList,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             int64_t)>();
    return op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        tag);
  }

  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) {
    static at::Tensor tensor;
    // TODO: if nccl was specified then use it
    if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
      // use a CUDA tensor for the NCCL barrier
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte));
    } else {
      // default to the CPU implementation
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::CPU).dtype(at::kByte));
    }

    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("c10d::barrier", "")
                         .typed<c10::intrusive_ptr<::c10d::Work>(
                             at::Tensor,
                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                             const std::vector<int64_t>&,
                             int64_t)>();

    return op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count());
  }

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  bool hasBackends() {
    return !deviceTypeToBackendType_.empty();
  }

  void setBackend(
      c10::DeviceType deviceType,
      BackendType backendType,
      const c10::optional<c10::intrusive_ptr<Backend>>& backend) {
    deviceTypeToBackendType_[deviceType] = backendType;
    // if a backend is already registered for this backendType, reuse it for
    // this device
    if (backendTypeToBackend_.find(backendType) !=
        backendTypeToBackend_.end()) {
      auto existingBackend = backendTypeToBackend_.at(backendType);
      deviceTypeToBackend_[deviceType] = existingBackend;
    } else {
      // otherwise, register the new backend (if one was provided)
      if (backend.has_value()) {
        deviceTypeToBackend_[deviceType] = backend.value();
        backendTypeToBackend_[backendType] = backend.value();
      }
    }
  }
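
  // Illustrative registration sketch only (assumes a Gloo backend instance
  // `glooBackend` constructed elsewhere; it is not part of this header):
  //
  //   pg->setBackend(
  //       c10::DeviceType::CPU, ProcessGroup::BackendType::GLOO, glooBackend);
  //   pg->setBackend(
  //       c10::DeviceType::CUDA, ProcessGroup::BackendType::GLOO, c10::nullopt);
  //   // Both device types now resolve to the same Gloo backend instance:
  //   auto backend = pg->getBackend(c10::DeviceType::CUDA);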

  c10::intrusive_ptr<Backend> getDefaultBackend() const {
    TORCH_CHECK(
        backendTypeToBackend_.find(backendType_) != backendTypeToBackend_.end(),
        "Could not find the default backend type ",
        backendType_,
        " for Process Group with name ",
        getBackendName(),
        ".");
    return backendTypeToBackend_.at(backendType_);
  }

  c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);

  c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
    TORCH_CHECK(
        backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end(),
        "Could not find backend type ",
        backendType,
        ".");
    return backendTypeToBackend_.at(backendType);
  }

 protected:
  // Implementations of this interface need to call this to set up
  // appropriate logging etc.
  void init();

  const c10::intrusive_ptr<c10d::Store> store_;
  const int rank_;
  const int size_;
  const c10::intrusive_ptr<Options> options_;
  const BackendType backendType_;
  // Optional sequence number structure for matching collectives.
  c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;

  // Debug level setting. It is parsed once when the ProcessGroup is
  // constructed and remains the same across all uses of this process group.
  DebugLevel dist_debug_level_;

  // Backend classes for this ProcessGroup
  std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
  std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
      deviceTypeToBackend_;
  std::unordered_map<BackendType, c10::intrusive_ptr<Backend>>
      backendTypeToBackend_;
};

} // namespace c10d