#pragma once

#ifdef USE_C10D_GLOO

#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

#include <gloo/algorithm.h>
#include <gloo/common/error.h>
#include <gloo/context.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/device.h>

#include <c10/util/hash.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

constexpr const char* GLOO_BACKEND_NAME = "gloo";

// ProcessGroupGloo implements Gloo bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way we can
// guarantee that the same calls are matched up across processes. For
// multi-threaded usage of process groups, consider using multiple
// process group instances.
//
// The Gloo algorithms that this class calls into are cached by their
// signature (see the description of AlgorithmKey). This cache works
// as follows: every function call instantiates an AlgorithmKey and
// looks in the cache for existing entries. If there is one, it is
// removed from the cache and returned to the caller. If there are
// none, a new entry is created and returned. If an entry was created
// before, but is still in use, the call will block and wait until the
// entry is returned to the cache.
//
// In the future, we hope to extend this to allow multiple entries per
// key, to enable parallelism for a single key. The number of entries
// per key must always be identical for all processes. This maximum
// number can be automatically tuned, but only if we let a single
// process take charge, and have it broadcast the limits.
//
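// Example (a minimal usage sketch; the store construction, the rank/size
// variables, and the tensor shape below are illustrative assumptions, not
// part of this header):
//
//   c10::intrusive_ptr<c10d::Store> store = /* e.g. a c10d::FileStore */;
//   auto options = c10d::ProcessGroupGloo::Options::create();
//   options->devices.push_back(
//       c10d::ProcessGroupGloo::createDefaultDevice());
//   auto pg = c10::make_intrusive<c10d::ProcessGroupGloo>(
//       store, rank, size, options);
//
//   // Every rank must issue the same collectives in the same order.
//   std::vector<at::Tensor> tensors = {at::ones({16})};
//   pg->allreduce(tensors)->wait();
//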
class TORCH_API ProcessGroupGloo : public Backend {
 public:
  // AsyncWork is the Gloo specific superclass for asynchronous work items.
  // We can split asynchronous work into 3 phases:
  // 1) Sanity checks and input preparation (e.g. memcpy)
  // 2) Run operation on background thread
  // 3) Synchronize with completion on foreground thread
  //
  // There is state to be shared between these 3 phases and all of this state
  // is captured in the AsyncWork class and its derivatives.
  //
  // Note: while we are porting operations to use new style collectives, there
  // is a split between operations using the existing caching approach and
  // operations using the new AsyncWork base class. Over time we will port
  // all operations and perform needed cleanup.
  //
  // FIXME: This probably should be called WorkGloo since the work is
  // executed in sync mode by a background thread.
  class TORCH_API AsyncWork : public Work {
   public:
    explicit AsyncWork(
        std::vector<std::vector<at::Tensor>> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt);

    ~AsyncWork() override = default;

    static void execute(c10::intrusive_ptr<AsyncWork> work);

    virtual void run() = 0;

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

   protected:
    friend class ProcessGroupGloo;

   private:
    void finishWorkGloo();
    void finishWorkGlooError(std::exception_ptr eptr);
    inline void recordAsyncWorkProfilingInfo(
        const char* profilingTitle,
        const c10::optional<std::vector<at::Tensor>>& inputTensors);

    const std::vector<std::vector<at::Tensor>> outputTensors_;
    c10::intrusive_ptr<c10::ivalue::Future> future_;
    std::function<void()> recordFunctionBeforeCallback_;
  };
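
  // Illustrative sketch (a hypothetical subclass, not part of this header):
  // a derived work item prepares inputs at construction time (phase 1),
  // implements run() with the blocking Gloo call (phase 2), and lets
  // callers synchronize through wait() on the returned work (phase 3):
  //
  //   class HypotheticalAllreduceWork : public ProcessGroupGloo::AsyncWork {
  //    public:
  //     explicit HypotheticalAllreduceWork(std::vector<at::Tensor> inputs)
  //         : AsyncWork({inputs}, "gloo:hypothetical_allreduce"),
  //           inputs_(std::move(inputs)) {}
  //
  //     void run() override {
  //       // Invoked on a background thread; runs the blocking collective.
  //     }
  //
  //    private:
  //     std::vector<at::Tensor> inputs_;
  //   };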

  // Wrap c10d store as Gloo store
  class TORCH_API GlooStore : public ::gloo::rendezvous::Store {
   public:
    explicit GlooStore(const c10::intrusive_ptr<::c10d::Store>& store)
        : store_(store) {}

    void setUint(const std::string& key, const std::vector<uint8_t>& value) {
      store_->set(key, value);
    }

    void set(const std::string& key, const std::vector<char>& value) override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      store_->set(key, tmp);
    }

    std::vector<uint8_t> getUint(const std::string& key) {
      auto value = store_->get(key);
      return value;
    }

    std::vector<char> get(const std::string& key) override {
      auto value = store_->get(key);
      return std::vector<char>(value.begin(), value.end());
    }

    void wait(const std::vector<std::string>& keys) override {
      store_->wait(keys, Store::kDefaultTimeout);
    }

    void wait(
        const std::vector<std::string>& keys,
        const std::chrono::milliseconds& timeout) override {
      store_->wait(keys, timeout);
    }

   protected:
    c10::intrusive_ptr<::c10d::Store> store_;
  };
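
  // Usage sketch (the FileStore path and worker count are illustrative
  // assumptions):
  //
  //   auto c10dStore = c10::make_intrusive<c10d::FileStore>("/tmp/store", 1);
  //   ProcessGroupGloo::GlooStore glooStore(c10dStore);
  //   glooStore.set("key", std::vector<char>{'v'});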

  // For send and recv operations there is no need to pass them to the
  // thread pool as they are entirely completed by the device thread.
  // This work object is used to synchronize completion of the send or
  // recv operation. It keeps a reference to the tensor it is
  // operating on to prevent it from being deallocated while the
  // operation is still in flight.
  class TORCH_API SendWork : public Work {
   public:
    explicit SendWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer);

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
  };

  class TORCH_API RecvWork : public Work {
   public:
    explicit RecvWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        const char* profilingTitle = nullptr);

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    int srcRank_;
  };

  struct TORCH_API Options : public Backend::Options {
    explicit Options(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout);

    // Returns an intrusive_ptr to a newly constructed Options instance.
    static c10::intrusive_ptr<Options> create(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout) {
      return c10::make_intrusive<Options>(timeout);
    }

    std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
    int threads;
  };
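
  // Example (a sketch; the thread count and device choice are illustrative):
  //
  //   auto options = ProcessGroupGloo::Options::create();
  //   options->devices.push_back(ProcessGroupGloo::createDefaultDevice());
  //   options->threads = 4;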

  const std::string getBackendName() const override {
    return std::string(GLOO_BACKEND_NAME);
  }

  // Helper functions to create a new device object.
  // They are static functions on this class to keep them logically
  // separate from the rest of the code base (e.g. torch/csrc/distributed).

  // Create new device instance for specific interface.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
      const std::string& interface);

  // Create new device instance for specific hostname or address.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
      const std::string& hostname);

  // Create new device instance.
  // It tries to resolve this machine's hostname and bind to that address.
  // If that fails (i.e. the hostname doesn't resolve to an address), it
  // falls back to binding to the loopback address.
  static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
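
  // Example (a sketch; "eth0" and "node0" are illustrative assumptions):
  //
  //   auto byInterface = ProcessGroupGloo::createDeviceForInterface("eth0");
  //   auto byHostname = ProcessGroupGloo::createDeviceForHostname("node0");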

  // Create ProcessGroupGloo instance.
  static c10::intrusive_ptr<ProcessGroupGloo> createProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      std::chrono::milliseconds timeout);

  explicit ProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options = Options::create());

  ~ProcessGroupGloo() override;

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& output_lists,
      std::vector<at::Tensor>& input_list,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputCounts,
      std::vector<int64_t>& inputCounts,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
    return store_;
  }

  // Similar to barrier(), but blocks rank 0 until all other ranks have
  // acknowledged that they are alive (through send/recv from rank 0). Rank 0
  // is able to report all failed ranks if waitAllRanks = true, otherwise
  // reports the first rank it detected as failed.
  void monitoredBarrier(
      const BarrierOptions& opts = BarrierOptions(),
      bool waitAllRanks = false) override;
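
  // Example (a sketch, given a constructed process group `pg`): have rank 0
  // wait for and report every unresponsive rank rather than failing on the
  // first one it detects:
  //
  //   pg->monitoredBarrier(c10d::BarrierOptions(), /*waitAllRanks=*/true);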

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override;

  int getNumThreads() {
    return options_->threads;
  }

 protected:
  std::unique_ptr<::gloo::rendezvous::Store> store_;
  const c10::intrusive_ptr<Options> options_;

  // Every Gloo context represents a set of connections to its peers.
  // In order to use more than one device (or allow for parallelism on
  // a single device), you need multiple contexts.
  std::vector<std::shared_ptr<::gloo::Context>> contexts_;
  std::vector<std::thread> threads_;
  bool stop_;

  // Incremented for every collective we kick off.
  // The value is used as the tag for collective operations. Collectives are
  // kicked off in identical order across processes, so the tag can be used
  // to match up operations during concurrent execution.
  uint32_t collectiveCounter_;

  // Returns next collective tag to use (uses collectiveCounter_).
  uint32_t nextTag();

  // Returns the context to use for the specified tag.
  // With `nextTag` returning an increasing number, this should lead
  // to contexts being used in a round-robin fashion.
  std::shared_ptr<::gloo::Context> getContext(uint32_t tag);
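
  // A plausible implementation sketch (an assumption; the actual definition
  // lives in the corresponding .cpp file):
  //
  //   std::shared_ptr<::gloo::Context> ProcessGroupGloo::getContext(
  //       uint32_t tag) {
  //     return contexts_[tag % contexts_.size()];
  //   }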

  // Entrypoint for worker threads.
  void runLoop(int workerIndex);

  // Queue work to run on worker thread.
  void enqueue(c10::intrusive_ptr<AsyncWork> work);

  // Keep both a queue of pending work and a vector with in-progress work.
  // Both of these can only be mutated while holding the queue lock.
  // We keep both around instead of just the queue so we can grab a weak_ptr
  // to all in-progress and pending work when executing a barrier.
  // A barrier, in particular, must ensure that all prior work has
  // completed before it can complete itself.
  std::deque<c10::intrusive_ptr<AsyncWork>> workQueue_;
  std::vector<c10::intrusive_ptr<AsyncWork>> workInProgress_;
  std::mutex workMutex_;
  std::condition_variable workProduceCV_;
  std::condition_variable workConsumeCV_;
};

} // namespace c10d

#endif // USE_C10D_GLOO