1 | #pragma once |
2 | |
3 | #ifdef USE_C10D_GLOO |
4 | |
5 | #include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp> |
6 | #include <torch/csrc/distributed/c10d/Types.hpp> |
7 | #include <torch/csrc/distributed/c10d/Utils.hpp> |
8 | |
9 | namespace c10d { |
10 | |
// A transparent wrapper around another c10d Backend: application collectives
// are dispatched to the wrapped backend (`backend_`), while a companion
// Gloo-based backend (`glooBackend_`) is used for the wrapper's own
// coordination work (monitored barrier, sequence-number agreement,
// collective-fingerprint collection) so that malformed or desynchronized
// collective calls in the application can be detected and reported.
class TORCH_API ProcessGroupWrapper : public Backend {
 public:
  // `backend`: the process group that actually executes application
  // collectives.
  // `glooBackend`: a Gloo process group used only for the wrapper's internal
  // checks and coordination.
  explicit ProcessGroupWrapper(
      c10::intrusive_ptr<Backend> backend,
      c10::intrusive_ptr<Backend> glooBackend);

  const std::string getBackendName() const override;

  // ---- Collective operations --------------------------------------------
  // NOTE(review): per this class's stated design each override presumably
  // runs runCollectiveChecks() before forwarding to the wrapped backend,
  // but the declarations alone do not show it — confirm in the .cpp.

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& data,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  // This function is deprecated and will be moved out of ProcessGroup to
  // comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  // Barrier that, unlike a regular barrier, reports which rank(s) failed to
  // reach it within the timeout. `waitAllRanks` controls whether all ranks'
  // failures are collected before erroring — NOTE(review): semantics inferred
  // from the name/signature; confirm against the c10d monitoredBarrier docs.
  void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
      override;

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only
  // implemented for GLOO and NCCL backends currently.
  // NOTE(review): an original stray note here read "dont implement this";
  // the override nevertheless exists — check the .cpp for what it does.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should
  // be in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override; // delegates to the wrapped
                                                 // backend (original note:
                                                 // "just call underlying")

  // ---- Point-to-point operations ----------------------------------------

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts) override;

  // Accessor for the underlying (wrapped) process group.
  c10::intrusive_ptr<Backend> getWrappedPg() const;

 private:
  // Underlying process group that actual application collectives will be
  // dispatched to.
  c10::intrusive_ptr<Backend> backend_;
  // Gloo process group responsible for internal coordination such as
  // monitored barrier, sequence number checking, collective fingerprint
  // collecting.
  c10::intrusive_ptr<Backend> glooBackend_;
  // Conducts several checks to ensure that the underlying collective is well
  // formed with the goal of notifying the user about incorrect collective use
  // in the application.
  void runCollectiveChecks(
      OpType op_type,
      const std::vector<at::Tensor>& tensors) const;
};
134 | } // namespace c10d |
135 | |
136 | #endif // USE_C10D_GLOO |
137 | |