1 | #pragma once |
2 | |
#include <chrono>
#include <condition_variable>
4 | #include <memory> |
5 | #include <mutex> |
6 | #include <stdexcept> |
7 | #include <unordered_map> |
8 | #include <utility> |
9 | #include <vector> |
10 | |
11 | #include <ATen/ATen.h> |
12 | #include <c10/macros/Macros.h> |
13 | |
14 | #include <torch/csrc/distributed/c10d/Work.hpp> |
15 | #include <torch/csrc/distributed/c10d/Types.hpp> |
16 | #include <torch/csrc/distributed/c10d/Utils.hpp> |
17 | #include <torch/csrc/distributed/c10d/debug.h> |
18 | #include <torch/csrc/distributed/c10d/sequence_num.hpp> |
19 | |
20 | constexpr auto kBackendDefaultTimeout = |
21 | std::chrono::milliseconds(30 * 60 * 1000); |
22 | |
23 | namespace c10d { |
24 | |
25 | class TORCH_API Backend : public torch::CustomClassHolder { |
26 | public: |
27 | |
  // Options is a base struct that defines the basic options used when
  // constructing a Backend. Each Backend subclass should extend this struct
  // and define its own options if it wants to provide more config options
  // (beyond the basic ones defined here) to the end user.
32 | struct TORCH_API Options : torch::CustomClassHolder { |
33 | explicit Options( |
34 | std::string backend, |
35 | std::chrono::milliseconds timeout = kBackendDefaultTimeout) |
36 | : timeout(timeout), backend(std::move(backend)) {} |
37 | ~Options() override = default; |
38 | |
39 | std::chrono::milliseconds timeout; |
40 | |
41 | // backend name |
42 | const std::string backend; |
43 | }; |
44 | |
45 | explicit Backend(int rank, int size); |
46 | ~Backend() override = 0; |
47 | |
48 | int getRank() const { |
49 | return rank_; |
50 | } |
51 | |
52 | int getSize() const { |
53 | return size_; |
54 | } |
55 | |
56 | virtual void startCoalescing() { |
57 | // no-op for backends that have not implemented startCoalescing |
58 | } |
59 | |
60 | virtual void endCoalescing( |
61 | std::vector<c10::intrusive_ptr<Work>>& /* reqs */) { |
62 | // no-op for backends that have not implemented endCoalescing |
63 | } |
64 | |
65 | // Subclasses must override this method to return the backend name |
  virtual const std::string getBackendName() const {
    TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
  }
69 | |
70 | virtual c10::intrusive_ptr<Work> broadcast( |
71 | std::vector<at::Tensor>& /* tensors */, |
72 | const BroadcastOptions& /* opts */ = BroadcastOptions()) { |
73 | TORCH_CHECK( |
74 | false, |
75 | c10::str("Backend " , getBackendName(), "does not support broadcast" )); |
76 | } |
77 | |
78 | virtual c10::intrusive_ptr<Work> allreduce( |
79 | std::vector<at::Tensor>& /* tensors */, |
80 | const AllreduceOptions& /* opts */ = AllreduceOptions()) { |
81 | TORCH_CHECK( |
82 | false, |
83 | c10::str("Backend " , getBackendName(), "does not support allreduce" )); |
84 | } |
85 | |
86 | virtual c10::intrusive_ptr<Work> allreduce_coalesced( |
87 | std::vector<at::Tensor>& /* tensors */, |
88 | const AllreduceCoalescedOptions& /* opts */ = |
89 | AllreduceCoalescedOptions()) { |
90 | TORCH_CHECK( |
91 | false, |
92 | c10::str( |
93 | "Backend " , |
94 | getBackendName(), |
95 | "does not support allreduce_coalesced" )); |
96 | } |
97 | |
98 | virtual c10::intrusive_ptr<Work> reduce( |
99 | std::vector<at::Tensor>& /* tensors */, |
100 | const ReduceOptions& /* opts */ = ReduceOptions()) { |
101 | TORCH_CHECK( |
102 | false, |
103 | c10::str("Backend " , getBackendName(), "does not support reduce" )); |
104 | } |
105 | |
106 | virtual c10::intrusive_ptr<Work> allgather( |
107 | std::vector<std::vector<at::Tensor>>& /* outputTensors */, |
108 | std::vector<at::Tensor>& /* inputTensors */, |
109 | const AllgatherOptions& /* opts */ = AllgatherOptions()) { |
110 | TORCH_CHECK( |
111 | false, |
112 | c10::str("Backend " , getBackendName(), "does not support allgather" )); |
113 | } |
114 | |
  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of the ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
119 | virtual c10::intrusive_ptr<Work> _allgather_base( |
120 | at::Tensor& /* outputBuffer */, |
121 | at::Tensor& /* inputBuffer */, |
122 | const AllgatherOptions& /* opts */ = AllgatherOptions()) { |
123 | TORCH_CHECK( |
124 | false, |
125 | c10::str( |
126 | "Backend " , getBackendName(), "does not support _allgather_base" )); |
127 | } |
128 | |
129 | // This function is deprecated and will be moved out of Backend to comms: |
130 | // * do not add dependencies on this function, |
131 | // * do not implement it in your Backend, implement _allgather_base |
132 | // instead. |
133 | virtual c10::intrusive_ptr<Work> allgather_coalesced( |
134 | std::vector<std::vector<at::Tensor>>& /* outputTensorLists */, |
135 | std::vector<at::Tensor>& /* inputTensors */, |
136 | const AllgatherOptions& /* opts */ = AllgatherOptions()) { |
137 | TORCH_CHECK( |
138 | false, |
139 | c10::str( |
140 | "Backend " , |
141 | getBackendName(), |
142 | "does not support allgather_coalesced" )); |
143 | } |
144 | |
145 | virtual c10::intrusive_ptr<Work> gather( |
146 | std::vector<std::vector<at::Tensor>>& /* outputTensors */, |
147 | std::vector<at::Tensor>& /* inputTensors */, |
148 | const GatherOptions& /* opts */ = GatherOptions()) { |
149 | TORCH_CHECK( |
150 | false, |
151 | c10::str("Backend " , getBackendName(), "does not support gather" )); |
152 | } |
153 | |
154 | virtual c10::intrusive_ptr<Work> scatter( |
155 | std::vector<at::Tensor>& /* outputTensors */, |
156 | std::vector<std::vector<at::Tensor>>& /* inputTensors */, |
157 | const ScatterOptions& /* opts */ = ScatterOptions()) { |
158 | TORCH_CHECK( |
159 | false, |
160 | c10::str("Backend " , getBackendName(), "does not support scatter" )); |
161 | } |
162 | |
163 | virtual c10::intrusive_ptr<Work> reduce_scatter( |
164 | std::vector<at::Tensor>& /* outputTensors */, |
165 | std::vector<std::vector<at::Tensor>>& /* inputTensors */, |
166 | const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) { |
167 | TORCH_CHECK( |
168 | false, |
169 | c10::str( |
170 | "Backend " , getBackendName(), "does not support reduce_scatter" )); |
171 | } |
172 | |
173 | virtual c10::intrusive_ptr<Work> _reduce_scatter_base( |
174 | at::Tensor& /* outputBuffer */, |
175 | at::Tensor& /* inputBuffer */, |
176 | const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) { |
177 | TORCH_CHECK( |
178 | false, |
179 | c10::str( |
180 | "Backend " , |
181 | getBackendName(), |
182 | "does not support _reduce_scatter_base" )); |
183 | } |
184 | |
185 | virtual c10::intrusive_ptr<Work> alltoall_base( |
186 | at::Tensor& /* outputBuffer */, |
187 | at::Tensor& /* inputBuffer */, |
188 | std::vector<int64_t>& /* outputSplitSizes */, |
189 | std::vector<int64_t>& /* inputSplitSizes */, |
190 | const AllToAllOptions& /* opts */ = AllToAllOptions()) { |
191 | TORCH_CHECK( |
192 | false, |
193 | c10::str( |
194 | "Backend " , getBackendName(), "does not support alltoall_base" )); |
195 | } |
196 | |
197 | virtual c10::intrusive_ptr<Work> alltoall( |
198 | std::vector<at::Tensor>& /* outputTensors */, |
199 | std::vector<at::Tensor>& /* inputTensors */, |
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
201 | TORCH_CHECK( |
202 | false, |
203 | c10::str("Backend " , getBackendName(), "does not support alltoall" )); |
204 | } |
205 | |
206 | virtual void monitoredBarrier( |
207 | const BarrierOptions& /* unused */, |
208 | bool /* unused */ = false) { |
209 | auto backendName = getBackendName(); |
210 | TORCH_CHECK( |
211 | false, |
212 | c10::str( |
213 | "Backend " , |
214 | backendName, |
215 | " does not support monitoredBarrier, only GLOO supports monitored barrier." )); |
216 | } |
217 | |
218 | // Agrees on an initial sequence number for the whole group by having rank 0 |
219 | // create it and broadcast it to other ranks using the store. Only implemented |
220 | // for GLOO and NCCL backends currently. |
221 | virtual void setSequenceNumberForGroup() { |
222 | auto backendName = getBackendName(); |
223 | TORCH_CHECK( |
224 | false, |
225 | c10::str( |
226 | "Backend " , |
227 | backendName, |
228 | " does not yet support sequence numbers." )); |
229 | } |
230 | |
231 | // Retrieves the current sequence number for the whole group, which should be |
232 | // in sync. If the returned number is not consistent across the group, it |
233 | // may indicate that there is some sort of collective desynchronization. |
234 | virtual uint64_t getSequenceNumberForGroup() { |
235 | auto backendName = getBackendName(); |
236 | TORCH_CHECK( |
237 | false, |
238 | c10::str( |
239 | "Backend " , |
240 | backendName, |
241 | " does not yet support sequence numbers." )); |
242 | } |
243 | |
244 | virtual c10::intrusive_ptr<Work> send( |
245 | std::vector<at::Tensor>& /* tensors */, |
246 | int /* dstRank */, |
247 | int /* tag */) { |
248 | TORCH_CHECK( |
249 | false, c10::str("Backend " , getBackendName(), "does not support send" )); |
250 | } |
251 | |
252 | virtual c10::intrusive_ptr<Work> recv( |
253 | std::vector<at::Tensor>& /* tensors */, |
254 | int /* srcRank */, |
255 | int /* tag */) { |
256 | TORCH_CHECK( |
257 | false, c10::str("Backend " , getBackendName(), "does not support recv" )); |
258 | } |
259 | |
260 | virtual c10::intrusive_ptr<Work> recvAnysource( |
261 | std::vector<at::Tensor>& /* tensors */, |
262 | int /* tag */) { |
263 | TORCH_CHECK( |
264 | false, |
265 | c10::str( |
266 | "Backend " , getBackendName(), "does not support recvAnysource" )); |
267 | } |
268 | |
269 | virtual c10::intrusive_ptr<Work> barrier( |
270 | const BarrierOptions& /* opts */ = BarrierOptions()) { |
271 | TORCH_CHECK( |
272 | false, |
273 | c10::str("Backend " , getBackendName(), "does not support barrier" )); |
274 | } |
275 | |
276 | protected: |
277 | // Implementations of this interface need to call this to setup |
278 | // appropriate logging etc. |
279 | void init(); |
280 | |
281 | // Optional sequence number structure for matching collectives. |
282 | c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt; |
283 | const int rank_; |
284 | const int size_; |
285 | // Debug level setting. It is parsed once when ProcessGroup is constructed and |
286 | // remains the same across use of this process group. |
287 | DebugLevel dist_debug_level_; |
288 | }; |
289 | |
290 | } // namespace c10d |
291 | |