#pragma once

#ifdef USE_C10D_GLOO

#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

#include <gloo/algorithm.h>
#include <gloo/common/error.h>
#include <gloo/context.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/device.h>

#include <c10/util/hash.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

constexpr const char* GLOO_BACKEND_NAME = "gloo";

// ProcessGroupGloo implements Gloo bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way that we
// can guarantee to match up the same calls across processes. For
// multi-threaded usage of process groups, you can consider using
// multiple process group instances.
//
// The Gloo algorithms that this class calls into are cached by their
// signature (see description of AlgorithmKey above). This cache works
// as follows: every function call instantiates an AlgorithmKey and
// looks in the cache for existing entries. If there is one, it is
// removed from the cache and returned to the caller. If there are
// none, a new entry is created and returned. If an entry was created
// before, but is still in use, the call will block and wait until the
// entry is returned to the cache.
//
// In the future, we hope to extend this to allow multiple entries per
// key, to enable parallelism for a single key. The number of entries
// per key must always be identical for all processes. This maximum
// number can be automatically tuned, but only if we let a single
// process take charge, and have it broadcast the limits.
//
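// A minimal usage sketch (illustrative only; the store setup and the
// rank/size values below are assumptions, not part of this header):
//
//   c10::intrusive_ptr<Store> store = /* e.g. a TCPStore or FileStore
//                                        shared by all ranks */;
//   auto options = ProcessGroupGloo::Options::create();
//   options->devices.push_back(ProcessGroupGloo::createDefaultDevice());
//   ProcessGroupGloo pg(store, /*rank=*/0, /*size=*/2, options);
//   std::vector<at::Tensor> tensors = {at::ones({4})};
//   pg.allreduce(tensors)->wait();
//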
class TORCH_API ProcessGroupGloo : public Backend {
 public:
  // AsyncWork is the Gloo specific superclass for asynchronous work items.
  // We can split asynchronous work into 3 phases:
  // 1) Sanity checks and prepare input (e.g. memcpy)
  // 2) Run operation on background thread
  // 3) Synchronize with completion on foreground thread
  //
  // There is state to be shared between these 3 phases and all of this state
  // is captured in the AsyncWork class and its derivatives.
  //
  // Note: while we are porting operations to use new style collectives, there
  // is a split between operations using the existing caching approach and
  // operations using the new AsyncWork base class. Over time we will port
  // all operations and perform needed cleanup.
  //
  // FIXME: This probably should be called WorkGloo since the work is executed
  // in sync mode by a background thread.
  class TORCH_API AsyncWork : public Work {
   public:
    explicit AsyncWork(
        std::vector<std::vector<at::Tensor>> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt);

    ~AsyncWork() override = default;

    static void execute(c10::intrusive_ptr<AsyncWork> work);

    virtual void run() = 0;

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

   protected:
    friend class ProcessGroupGloo;

   private:
    void finishWorkGloo();
    void finishWorkGlooError(std::exception_ptr eptr);
    inline void recordAsyncWorkProfilingInfo(
        const char* profilingTitle,
        const c10::optional<std::vector<at::Tensor>>& inputTensors);

    const std::vector<std::vector<at::Tensor>> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
    std::function<void()> recordFunctionBeforeCallback_;
  };
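
  // A minimal sketch of a derived work item (illustrative; "AsyncFooWork" and
  // the "gloo:foo" profiling title are assumed names). It mirrors the three
  // phases described above: inputs are validated when the item is built,
  // run() executes the Gloo call on a worker thread, and callers synchronize
  // through wait() or getFuture().
  //
  //   class AsyncFooWork : public AsyncWork {
  //    public:
  //     explicit AsyncFooWork(std::vector<at::Tensor> tensors)
  //         : AsyncWork({tensors}, "gloo:foo", tensors) {}
  //     void run() override {
  //       // Phase 2: invoke the Gloo collective on the background thread.
  //     }
  //   };
  //
  // The owning process group hands such items to its worker threads via
  // enqueue() (declared below).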

  // Wrap c10d store as Gloo store
  class TORCH_API GlooStore : public ::gloo::rendezvous::Store {
   public:
    GlooStore(const c10::intrusive_ptr<::c10d::Store>& store)
        : store_(store) {}

    void setUint(const std::string& key, const std::vector<uint8_t>& value) {
      store_->set(key, value);
    }

    void set(const std::string& key, const std::vector<char>& value) override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      store_->set(key, tmp);
    }

    std::vector<uint8_t> getUint(const std::string& key) {
      auto value = store_->get(key);
      return value;
    }

    std::vector<char> get(const std::string& key) override {
      auto value = store_->get(key);
      return std::vector<char>(value.begin(), value.end());
    }

    void wait(const std::vector<std::string>& keys) override {
      store_->wait(keys, Store::kDefaultTimeout);
    }

    void wait(
        const std::vector<std::string>& keys,
        const std::chrono::milliseconds& timeout) override {
      store_->wait(keys, timeout);
    }

   protected:
    c10::intrusive_ptr<::c10d::Store> store_;
  };
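
  // Illustrative sketch (assumed setup): the adapter lets Gloo's rendezvous
  // reuse whatever c10d store the process group was constructed with.
  //
  //   c10::intrusive_ptr<::c10d::Store> c10dStore = /* FileStore, TCPStore, ... */;
  //   GlooStore glooStore(c10dStore);
  //   glooStore.set("key", std::vector<char>{'a', 'b'});
  //   auto value = glooStore.get("key");  // round-trips through the c10d store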

  // For send and recv operations there is no need to pass them to the
  // thread pool as they are entirely completed by the device thread.
  // This work object is used to synchronize completion of the send or
  // recv operation. It keeps a reference to the tensor it is
  // operating on to prevent it from being deallocated while the
  // operation is still in flight.
  class TORCH_API SendWork : public Work {
   public:
    explicit SendWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer);

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
  };

  class TORCH_API RecvWork : public Work {
   public:
    explicit RecvWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        const char* profilingTitle = nullptr);

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    int srcRank_;
  };

  struct TORCH_API Options : public Backend::Options {
    explicit Options(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout);

    // return intrusive_ptr of the object
    static c10::intrusive_ptr<Options> create(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout) {
      return c10::make_intrusive<Options>(timeout);
    }

    std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
    int threads;
  };
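
  // Illustrative sketch: the options carry the Gloo transport devices and the
  // size of the worker thread pool (the interface name and values below are
  // example assumptions).
  //
  //   auto options = Options::create(std::chrono::seconds(30));
  //   options->devices.push_back(createDeviceForInterface("eth0"));
  //   options->threads = 2;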

  const std::string getBackendName() const override {
    return std::string(GLOO_BACKEND_NAME);
  }

  // Helper functions to create a new device object.
  // They are static functions on this class to keep them logically
  // separate from the rest of the code base (e.g. torch/csrc/distributed).

  // Create new device instance for specific interface.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
      const std::string& interface);

  // Create new device instance for specific hostname or address.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
      const std::string& hostname);

  // Create new device instance.
  // It tries to resolve this machine's hostname and bind to that address.
  // If that fails (i.e. the hostname doesn't resolve to an address), it
  // falls back to binding to the loopback address.
  static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
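
  // Illustrative examples of the helpers above (hostname and interface names
  // are placeholders); any of them can be used to populate Options::devices.
  //
  //   auto dev = createDeviceForHostname("node0.example.com");
  //   // or: createDeviceForInterface("eth0");
  //   // or: createDefaultDevice();  // this host's name, else loopback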

  // Create ProcessGroupGloo instance.
  static c10::intrusive_ptr<ProcessGroupGloo> createProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      std::chrono::milliseconds timeout);

  explicit ProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options = Options::create());

  ~ProcessGroupGloo() override;

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& output_lists,
      std::vector<at::Tensor>& input_list,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputCounts,
      std::vector<int64_t>& inputCounts,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
    return store_;
  }

  // Similar to barrier(), but blocks rank 0 until all other ranks have
  // acknowledged that they are alive (through send/recv from rank 0). Rank 0
  // is able to report all failed ranks if waitAllRanks = true, otherwise
  // reports the first rank it detected as failed.
  void monitoredBarrier(
      const BarrierOptions& opts = BarrierOptions(),
      bool waitAllRanks = false) override;
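
  // Illustrative sketch: with waitAllRanks = true, rank 0 waits for an
  // acknowledgement from every peer and can report all unresponsive ranks
  // rather than only the first one it detects (opts.timeout is an example
  // value).
  //
  //   BarrierOptions opts;
  //   opts.timeout = std::chrono::seconds(60);
  //   pg.monitoredBarrier(opts, /*waitAllRanks=*/true);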

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override;

  int getNumThreads() {
    return options_->threads;
  }

 protected:
  std::unique_ptr<::gloo::rendezvous::Store> store_;
  const c10::intrusive_ptr<Options> options_;

  // Every Gloo context represents a set of connections to its peers.
  // In order to use more than one device (or allow for parallelism on
  // a single device), you need multiple contexts.
  std::vector<std::shared_ptr<::gloo::Context>> contexts_;
  std::vector<std::thread> threads_;
  bool stop_;

  // Incremented for every collective we kick off.
  // The value is used as tag for collective operations. Collectives are kicked
  // off in identical order across processes. Therefore the tag can be used
  // to match up operations during concurrent execution.
  uint32_t collectiveCounter_;

  // Returns next collective tag to use (uses collectiveCounter_).
  uint32_t nextTag();

  // Returns the context to use for the specified tag.
  // With `nextTag` returning an increasing number, this should lead
  // to contexts being used in a round-robin fashion.
  std::shared_ptr<::gloo::Context> getContext(uint32_t tag);
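
  // Illustrative note: with nextTag() monotonically increasing, a selection
  // along the lines of contexts_[tag % contexts_.size()] yields the
  // round-robin behavior described above (the actual logic lives in the
  // implementation file).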

  // Entrypoint for worker threads.
  void runLoop(int workerIndex);

  // Queue work to run on worker thread.
  void enqueue(c10::intrusive_ptr<AsyncWork> work);

  // Keep both a queue of pending work, and a vector with in progress work.
  // Both of these can only be mutated when holding the queue lock.
  // We keep both around instead of just the queue, so we can grab a weak_ptr
  // to all in progress and pending work when executing a barrier.
  // When executing a barrier, we need to ensure that all prior work
  // has completed before the barrier itself can complete.
  std::deque<c10::intrusive_ptr<AsyncWork>> workQueue_;
  std::vector<c10::intrusive_ptr<AsyncWork>> workInProgress_;
  std::mutex workMutex_;
  std::condition_variable workProduceCV_;
  std::condition_variable workConsumeCV_;
};

} // namespace c10d

#endif // USE_C10D_GLOO