1#include <torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp>
2
3namespace c10d {
4
5ProcessGroupRoundRobin::ProcessGroupRoundRobin(
6 int rank,
7 int size,
8 std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups)
9 : ProcessGroup(rank, size), processGroups_(std::move(processGroups)) {
10 TORCH_WARN(
11 "ProcessGroupRoundRobin is deprecated and scheduled to be removed after this current release (1.13). ",
12 "Please file an issue on https://github.com/pytorch/pytorch/issues if there are any concerns or issues with this deprecation.");
13 TORCH_CHECK(!processGroups_.empty());
14 for (const auto& processGroup : processGroups_) {
15 TORCH_CHECK(processGroup->getRank() == rank_);
16 TORCH_CHECK(processGroup->getSize() == size_);
17 }
18 iterator_ = processGroups_.begin();
19}
20
// Defaulted out-of-line destructor: no cleanup beyond destroying the members
// (including the vector of wrapped process groups).
ProcessGroupRoundRobin::~ProcessGroupRoundRobin() = default;
22
23c10::intrusive_ptr<Work> ProcessGroupRoundRobin::broadcast(
24 std::vector<at::Tensor>& tensors,
25 const BroadcastOptions& opts) {
26 return next()->broadcast(tensors, opts);
27}
28
29c10::intrusive_ptr<Work> ProcessGroupRoundRobin::allreduce(
30 std::vector<at::Tensor>& tensors,
31 const AllreduceOptions& opts) {
32 return next()->allreduce(tensors, opts);
33}
34
35c10::intrusive_ptr<Work> ProcessGroupRoundRobin::allreduce_coalesced(
36 std::vector<at::Tensor>& tensors,
37 const AllreduceCoalescedOptions& opts) {
38 return next()->allreduce_coalesced(tensors, opts);
39}
40
41c10::intrusive_ptr<Work> ProcessGroupRoundRobin::reduce(
42 std::vector<at::Tensor>& tensors,
43 const ReduceOptions& opts) {
44 return next()->reduce(tensors, opts);
45}
46
47c10::intrusive_ptr<Work> ProcessGroupRoundRobin::allgather(
48 std::vector<std::vector<at::Tensor>>& outputs,
49 std::vector<at::Tensor>& inputs,
50 const AllgatherOptions& opts) {
51 return next()->allgather(outputs, inputs, opts);
52};
53
54c10::intrusive_ptr<Work> ProcessGroupRoundRobin::allgather_coalesced(
55 std::vector<std::vector<at::Tensor>>& outputTensorLists,
56 std::vector<at::Tensor>& inputTensors,
57 const AllgatherOptions& opts) {
58 return next()->allgather(outputTensorLists, inputTensors, opts);
59}
60
61c10::intrusive_ptr<Work> ProcessGroupRoundRobin::gather(
62 std::vector<std::vector<at::Tensor>>& outputs,
63 std::vector<at::Tensor>& inputs,
64 const GatherOptions& opts) {
65 return next()->gather(outputs, inputs, opts);
66};
67
68c10::intrusive_ptr<Work> ProcessGroupRoundRobin::scatter(
69 std::vector<at::Tensor>& outputs,
70 std::vector<std::vector<at::Tensor>>& inputs,
71 const ScatterOptions& opts) {
72 return next()->scatter(outputs, inputs, opts);
73};
74
75c10::intrusive_ptr<Work> ProcessGroupRoundRobin::reduce_scatter(
76 std::vector<at::Tensor>& outputs,
77 std::vector<std::vector<at::Tensor>>& inputs,
78 const ReduceScatterOptions& opts) {
79 return next()->reduce_scatter(outputs, inputs, opts);
80};
81
82c10::intrusive_ptr<Work> ProcessGroupRoundRobin::alltoall_base(
83 at::Tensor& outputTensor,
84 at::Tensor& inputTensor,
85 std::vector<int64_t>& outputSplitSizes,
86 std::vector<int64_t>& inputSplitSizes,
87 const AllToAllOptions& opts) {
88 return next()->alltoall_base(
89 outputTensor, inputTensor, outputSplitSizes, inputSplitSizes, opts);
90};
91
92c10::intrusive_ptr<Work> ProcessGroupRoundRobin::send(
93 std::vector<at::Tensor>& /* unused */,
94 int /* unused */,
95 int /* unused */) {
96 TORCH_CHECK(false, "ProcessGroupRoundRobin does not support send");
97};
98
99c10::intrusive_ptr<Work> ProcessGroupRoundRobin::recv(
100 std::vector<at::Tensor>& /* unused */,
101 int /* unused */,
102 int /* unused */) {
103 TORCH_CHECK(false, "ProcessGroupRoundRobin does not support recv");
104};
105
106c10::intrusive_ptr<Work> ProcessGroupRoundRobin::recvAnysource(
107 std::vector<at::Tensor>& /* unused */,
108 int /* unused */) {
109 TORCH_CHECK(false, "ProcessGroupRoundRobin does not support recv");
110};
111
112c10::intrusive_ptr<Work> ProcessGroupRoundRobin::barrier(
113 const BarrierOptions& /* unused */) {
114 TORCH_CHECK(false, "ProcessGroupRoundRobin does not support barrier");
115};
116
117const c10::intrusive_ptr<ProcessGroup>& ProcessGroupRoundRobin::next() {
118 auto& processGroup = *iterator_;
119 iterator_++;
120 if (iterator_ == processGroups_.end()) {
121 iterator_ = processGroups_.begin();
122 }
123 return processGroup;
124}
125
126c10::intrusive_ptr<Work> ProcessGroupRoundRobin::_allgather_base(
127 at::Tensor& /*unused */,
128 at::Tensor& /*unused */,
129 const AllgatherOptions& /*unused */) {
130 TORCH_CHECK(
131 false, "no support for _allgather_base in RoundRobin process group");
132}
133
134} // namespace c10d
135