1#pragma once
2
3#include <vector>
4
5#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
6
7namespace c10d {
8
// Backend name string; ProcessGroupRoundRobin::getBackendName() wraps it in a
// std::string.
constexpr char const* ROUND_ROBIN_BACKEND_NAME{"round_robin"};
10
11// ProcessGroupRoundRobin implements simple load balancing.
12//
13// It is constructed with multiple processes groups. Each call is dispatched to
14// one of the specified process groups in a round robin fashion. Each process
15// group instance must have the same rank and size.
16//
17// All functions of the class are expected to be called in the same order
18// across all processes in the process group. This is the only way that we
19// can guarantee to match up the same calls among all processes.
20//
21class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup {
22 public:
23 explicit ProcessGroupRoundRobin(
24 int rank,
25 int size,
26 std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups);
27
28 ~ProcessGroupRoundRobin() override;
29
30 const std::string getBackendName() const override {
31 return std::string(ROUND_ROBIN_BACKEND_NAME);
32 }
33
34 c10::intrusive_ptr<Work> broadcast(
35 std::vector<at::Tensor>& tensors,
36 const BroadcastOptions& opts = BroadcastOptions()) override;
37
38 c10::intrusive_ptr<Work> allreduce(
39 std::vector<at::Tensor>& tensors,
40 const AllreduceOptions& opts = AllreduceOptions()) override;
41
42 c10::intrusive_ptr<Work> allreduce_coalesced(
43 std::vector<at::Tensor>& tensors,
44 const AllreduceCoalescedOptions& opts =
45 AllreduceCoalescedOptions()) override;
46
47 c10::intrusive_ptr<Work> reduce(
48 std::vector<at::Tensor>& tensors,
49 const ReduceOptions& opts = ReduceOptions()) override;
50
51 c10::intrusive_ptr<Work> allgather(
52 std::vector<std::vector<at::Tensor>>& outputs,
53 std::vector<at::Tensor>& inputs,
54 const AllgatherOptions& opts = AllgatherOptions()) override;
55
56 c10::intrusive_ptr<Work> _allgather_base(
57 at::Tensor& outputBuffer,
58 at::Tensor& inputBuffer,
59 const AllgatherOptions& opts = AllgatherOptions()) override;
60
61 c10::intrusive_ptr<Work> allgather_coalesced(
62 std::vector<std::vector<at::Tensor>>& outputTensorLists,
63 std::vector<at::Tensor>& inputTensors,
64 const AllgatherOptions& opts = AllgatherOptions()) override;
65
66 c10::intrusive_ptr<Work> gather(
67 std::vector<std::vector<at::Tensor>>& outputs,
68 std::vector<at::Tensor>& inputs,
69 const GatherOptions& opts = GatherOptions()) override;
70
71 c10::intrusive_ptr<Work> scatter(
72 std::vector<at::Tensor>& outputs,
73 std::vector<std::vector<at::Tensor>>& inputs,
74 const ScatterOptions& opts = ScatterOptions()) override;
75
76 c10::intrusive_ptr<Work> reduce_scatter(
77 std::vector<at::Tensor>& outputs,
78 std::vector<std::vector<at::Tensor>>& inputs,
79 const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
80
81 c10::intrusive_ptr<Work> alltoall_base(
82 at::Tensor& outputTensor,
83 at::Tensor& inputTensor,
84 std::vector<int64_t>& outputSplitSizes,
85 std::vector<int64_t>& inputSplitSizes,
86 const AllToAllOptions& opts = AllToAllOptions()) override;
87
88 c10::intrusive_ptr<Work> send(
89 std::vector<at::Tensor>& tensors,
90 int dstRank,
91 int tag) override;
92
93 c10::intrusive_ptr<Work> recv(
94 std::vector<at::Tensor>& tensors,
95 int srcRank,
96 int tag) override;
97
98 c10::intrusive_ptr<Work> recvAnysource(
99 std::vector<at::Tensor>& tensors,
100 int tag) override;
101
102 c10::intrusive_ptr<Work> barrier(
103 const BarrierOptions& opts = BarrierOptions()) override;
104
105 private:
106 std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups_;
107 std::vector<c10::intrusive_ptr<ProcessGroup>>::const_iterator iterator_;
108
109 // Returns the next ProcessGroup to use.
110 const c10::intrusive_ptr<ProcessGroup>& next();
111};
112
113} // namespace c10d
114