#pragma once

#ifdef USE_C10D_MPI

#include <condition_variable>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/ivalue_inl.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <c10/util/CallOnce.h>

#include <mpi.h>

namespace c10d {

constexpr const char* MPI_BACKEND_NAME = "mpi";

// WorkEntry is the state associated with a single MPI run instance.
// It includes the source tensor list and the destination tensor list, as well
// as the actual run function that will operate on either src or dst or both.
struct WorkEntry {
  explicit WorkEntry(
      std::vector<at::Tensor>* srcPtr,
      std::vector<at::Tensor>* dstPtr,
      std::function<void(std::unique_ptr<WorkEntry>&)> run)
      : dst(dstPtr ? *dstPtr : std::vector<at::Tensor>()),
        run(std::move(run)) {
    if (srcPtr) {
      src = *srcPtr;
    }
  }

  // Not copyable
  WorkEntry(const WorkEntry&) = delete;
  // Not copy assignable
  WorkEntry& operator=(const WorkEntry&) = delete;

  // For input and output tensors (in-place), we will always use src
  std::vector<at::Tensor> src;

  // Copy of user provided outputs.
  const std::vector<at::Tensor> dst;

  // src rank returned, for recv only
  int* srcRank = nullptr;
  std::function<void(std::unique_ptr<WorkEntry>&)> run;
};
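
// Illustrative sketch (an assumption about how the implementation uses
// WorkEntry internally, not part of the public API): a collective is
// typically queued by wrapping the MPI call in a run function, e.g.
//
//   std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
//       [](std::unique_ptr<WorkEntry>& entry) {
//         auto& tensor = entry->src[0];
//         // ... issue the serialized MPI call on tensor's data here ...
//       };
//   auto entry =
//       std::make_unique<WorkEntry>(&tensors, &tensors, std::move(runFunc));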

// ProcessGroupMPI implements MPI bindings for c10d.
//
// All functions of this class are expected to be called in the same order
// across all processes in the group. This is the only way that we can
// guarantee to match up the same calls among all processes.
//
// All MPI functions provided by this class are asynchronously scheduled on a
// worker thread. Therefore, ProcessGroupMPI requires the MPI implementation
// that is used to have a minimum thread support value of
// MPI_THREAD_SERIALIZED. That is, the process may be multi-threaded, and
// multiple threads may make MPI calls, but only one at a time: MPI calls are
// not made concurrently from two distinct threads (all MPI calls are
// serialized). However, with MPI_THREAD_SERIALIZED, ProcessGroupMPI will only
// support a single process group. In other words, no more than one process
// group can be created globally.
//
// If you would like to use multiple ProcessGroupMPI, it requires your MPI
// implementation to have a thread support value of MPI_THREAD_MULTIPLE, that
// is, multiple threads may call MPI, with no restrictions.
//
// Also note that ProcessGroupMPI only supports a single tensor operation. In
// other words, the size of the input tensor vector should always be 1.
//
// CUDA tensors are supported if the MPI implementation used is CUDA-aware,
// and ProcessGroupMPI will automatically detect this support.
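//
// Example usage (an illustrative sketch; it assumes the program is launched
// under mpirun and that every rank issues the same calls in the same order):
//
//   auto pg = ProcessGroupMPI::createProcessGroupMPI();
//   std::vector<at::Tensor> tensors = {at::ones({16, 16})};
//   auto work = pg->allreduce(tensors); // asynchronously scheduled
//   work->wait(); // block until the MPI call has completed
//   // tensors[0] now holds the element-wise sum across all ranks.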
class TORCH_API ProcessGroupMPI : public Backend {
 public:
  class WorkMPI : public Work {
   public:
    explicit WorkMPI(
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt)
        : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
          outputTensors_(std::move(outputTensors)),
          future_(c10::make_intrusive<at::ivalue::Future>(
              c10::ListType::create(c10::TensorType::get()))) {}

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

   protected:
    friend class ProcessGroupMPI;

   private:
    void finishWorkMPI();
    void finishWorkMPIError(std::exception_ptr eptr);

    std::vector<at::Tensor> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
  };

  class AsyncWork : public Work {
   public:
    AsyncWork(
        MPI_Request request,
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt);

    virtual ~AsyncWork();

    bool isCompleted() override;

    bool isSuccess() const override;

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;

    void abort() override;

    std::vector<at::Tensor> result() override;

   protected:
    void populateException();

   private:
    const std::vector<at::Tensor> outputTensors_;
    MPI_Request request_;
    MPI_Status status_;
  };

  // The constructor spawns the worker thread loop
  explicit ProcessGroupMPI(int rank, int size, MPI_Comm pgComm);

  virtual ~ProcessGroupMPI();

  // Abort the MPI program; needs to be called when an exception is detected
  void abort();

  const std::string getBackendName() const override {
    return std::string(MPI_BACKEND_NAME);
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputbuffer,
      at::Tensor& inputbuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensor,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  // Create a new ProcessGroupMPI; this will initialize MPI if it has not
  // already been initialized.
  static c10::intrusive_ptr<ProcessGroupMPI> createProcessGroupMPI(
      std::vector<int> ranks = {});
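  //
  // Illustrative sketch (an assumption, since MPI communicator creation is
  // collective: every rank of the parent communicator is expected to reach
  // this call, even ranks excluded from the new group):
  //
  //   // On every rank, build a group containing only ranks 0 and 1.
  //   auto subPg = ProcessGroupMPI::createProcessGroupMPI({0, 1});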

 protected:
  using WorkType =
      std::tuple<std::unique_ptr<WorkEntry>, c10::intrusive_ptr<WorkMPI>>;
  // Worker thread loop
  void runLoop();
  // Helper function that is called by the destructor
  void destroy();

  c10::intrusive_ptr<Work> enqueue(
      std::unique_ptr<WorkEntry> entry,
      const char* profilingTitle = nullptr,
      const c10::optional<std::vector<at::Tensor>>& inputTensors = c10::nullopt);
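  //
  // Illustrative sketch of the intended producer/consumer handoff between
  // enqueue() and runLoop() (an assumption about the implementation in the
  // corresponding .cpp file, shown here only for orientation):
  //
  //   // enqueue() side: publish the work, then wake the worker thread.
  //   std::unique_lock<std::mutex> lock(pgMutex_);
  //   queue_.emplace_back(std::move(entry), std::move(work));
  //   lock.unlock();
  //   queueProduceCV_.notify_one();
  //
  //   // runLoop() side: sleep until there is work or we are stopping.
  //   queueProduceCV_.wait(lock, [&] { return stop_ || !queue_.empty(); });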

  bool stop_;

  std::mutex pgMutex_;
  std::thread workerThread_;

  std::deque<WorkType> queue_;
  std::condition_variable queueProduceCV_;
  std::condition_variable queueConsumeCV_;

  // Global states
  static void initMPIOnce();
  static void mpiExit();
  static c10::once_flag onceFlagInitMPI;

  static std::mutex pgGlobalMutex_;
  static int mpiThreadSupport_;

  MPI_Comm pgComm_;
};

} // namespace c10d

#endif // USE_C10D_MPI