1 | #pragma once |
2 | |
3 | #include <vector> |
4 | |
5 | #include <torch/csrc/distributed/c10d/ProcessGroup.hpp> |
6 | |
7 | namespace c10d { |
8 | |
// Backend name string; returned (wrapped in a std::string) by
// ProcessGroupRoundRobin::getBackendName().
constexpr char const* ROUND_ROBIN_BACKEND_NAME{"round_robin"};
10 | |
11 | // ProcessGroupRoundRobin implements simple load balancing. |
12 | // |
13 | // It is constructed with multiple process groups. Each call is dispatched to |
14 | // one of the specified process groups in a round robin fashion. Each process |
15 | // group instance must have the same rank and size. |
16 | // |
17 | // All functions of the class are expected to be called in the same order |
18 | // across all processes in the process group. This is the only way that we |
19 | // can guarantee to match up the same calls among all processes. |
20 | // |
21 | class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup { |
22 | public: |
23 | explicit ProcessGroupRoundRobin( |
24 | int rank, |
25 | int size, |
26 | std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups); |
27 | |
28 | ~ProcessGroupRoundRobin() override; |
29 | |
30 | const std::string getBackendName() const override { |
31 | return std::string(ROUND_ROBIN_BACKEND_NAME); |
32 | } |
33 | |
34 | c10::intrusive_ptr<Work> broadcast( |
35 | std::vector<at::Tensor>& tensors, |
36 | const BroadcastOptions& opts = BroadcastOptions()) override; |
37 | |
38 | c10::intrusive_ptr<Work> allreduce( |
39 | std::vector<at::Tensor>& tensors, |
40 | const AllreduceOptions& opts = AllreduceOptions()) override; |
41 | |
42 | c10::intrusive_ptr<Work> allreduce_coalesced( |
43 | std::vector<at::Tensor>& tensors, |
44 | const AllreduceCoalescedOptions& opts = |
45 | AllreduceCoalescedOptions()) override; |
46 | |
47 | c10::intrusive_ptr<Work> reduce( |
48 | std::vector<at::Tensor>& tensors, |
49 | const ReduceOptions& opts = ReduceOptions()) override; |
50 | |
51 | c10::intrusive_ptr<Work> allgather( |
52 | std::vector<std::vector<at::Tensor>>& outputs, |
53 | std::vector<at::Tensor>& inputs, |
54 | const AllgatherOptions& opts = AllgatherOptions()) override; |
55 | |
56 | c10::intrusive_ptr<Work> _allgather_base( |
57 | at::Tensor& outputBuffer, |
58 | at::Tensor& inputBuffer, |
59 | const AllgatherOptions& opts = AllgatherOptions()) override; |
60 | |
61 | c10::intrusive_ptr<Work> allgather_coalesced( |
62 | std::vector<std::vector<at::Tensor>>& outputTensorLists, |
63 | std::vector<at::Tensor>& inputTensors, |
64 | const AllgatherOptions& opts = AllgatherOptions()) override; |
65 | |
66 | c10::intrusive_ptr<Work> gather( |
67 | std::vector<std::vector<at::Tensor>>& outputs, |
68 | std::vector<at::Tensor>& inputs, |
69 | const GatherOptions& opts = GatherOptions()) override; |
70 | |
71 | c10::intrusive_ptr<Work> scatter( |
72 | std::vector<at::Tensor>& outputs, |
73 | std::vector<std::vector<at::Tensor>>& inputs, |
74 | const ScatterOptions& opts = ScatterOptions()) override; |
75 | |
76 | c10::intrusive_ptr<Work> reduce_scatter( |
77 | std::vector<at::Tensor>& outputs, |
78 | std::vector<std::vector<at::Tensor>>& inputs, |
79 | const ReduceScatterOptions& opts = ReduceScatterOptions()) override; |
80 | |
81 | c10::intrusive_ptr<Work> alltoall_base( |
82 | at::Tensor& outputTensor, |
83 | at::Tensor& inputTensor, |
84 | std::vector<int64_t>& outputSplitSizes, |
85 | std::vector<int64_t>& inputSplitSizes, |
86 | const AllToAllOptions& opts = AllToAllOptions()) override; |
87 | |
88 | c10::intrusive_ptr<Work> send( |
89 | std::vector<at::Tensor>& tensors, |
90 | int dstRank, |
91 | int tag) override; |
92 | |
93 | c10::intrusive_ptr<Work> recv( |
94 | std::vector<at::Tensor>& tensors, |
95 | int srcRank, |
96 | int tag) override; |
97 | |
98 | c10::intrusive_ptr<Work> recvAnysource( |
99 | std::vector<at::Tensor>& tensors, |
100 | int tag) override; |
101 | |
102 | c10::intrusive_ptr<Work> barrier( |
103 | const BarrierOptions& opts = BarrierOptions()) override; |
104 | |
105 | private: |
106 | std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups_; |
107 | std::vector<c10::intrusive_ptr<ProcessGroup>>::const_iterator iterator_; |
108 | |
109 | // Returns the next ProcessGroup to use. |
110 | const c10::intrusive_ptr<ProcessGroup>& next(); |
111 | }; |
112 | |
113 | } // namespace c10d |
114 | |