#pragma once

#ifdef USE_C10D_MPI

#include <condition_variable>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/ivalue_inl.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

#include <c10/util/CallOnce.h>

#include <mpi.h>

namespace c10d {

constexpr const char* MPI_BACKEND_NAME = "mpi";

// WorkEntry is the state associated with a single MPI run instance.
// It includes the source Tensor list and destination Tensor list, as well as
// the actual run function that will operate on src, dst, or both.
struct WorkEntry {
  explicit WorkEntry(
      std::vector<at::Tensor>* srcPtr,
      std::vector<at::Tensor>* dstPtr,
      std::function<void(std::unique_ptr<WorkEntry>&)> run)
      : dst(dstPtr ? *dstPtr : std::vector<at::Tensor>()),
        run(std::move(run)) {
    if (srcPtr) {
      src = *srcPtr;
    }
  }

  // Not copyable
  WorkEntry(const WorkEntry&) = delete;
  // Not copy assignable
  WorkEntry& operator=(const WorkEntry&) = delete;

  // For input and output tensors (in-place), we will always use src
  std::vector<at::Tensor> src;

  // Copy of user provided outputs.
  const std::vector<at::Tensor> dst;

  // src rank returned, for recv only
  int* srcRank = nullptr;
  std::function<void(std::unique_ptr<WorkEntry>&)> run;
};
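
// Example: how a WorkEntry is typically constructed (an illustrative sketch;
// the lambda body is a hypothetical stand-in for the actual MPI call):
//
//   std::vector<at::Tensor> tensors = ...;
//   auto entry = std::make_unique<WorkEntry>(
//       &tensors, &tensors, [](std::unique_ptr<WorkEntry>& entry) {
//         // issue the (serialized) MPI call on entry->src / entry->dst
//       });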

// ProcessGroupMPI implements MPI bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way that we
// can guarantee to match up the same calls across processes.
//
// All MPI functions provided by this class are asynchronously scheduled on a
// worker thread. Therefore, ProcessGroupMPI requires the MPI implementation
// that is used to have a minimum thread support value of
// MPI_THREAD_SERIALIZED. That is, the process may be multi-threaded, and
// multiple threads may make MPI calls, but only one at a time: MPI calls are
// not made concurrently from two distinct threads (all MPI calls are
// serialized). However, with MPI_THREAD_SERIALIZED, ProcessGroupMPI will only
// support a single process group. In other words, no more than one process
// group can be created globally.
//
// If you would like to use multiple ProcessGroupMPI instances, your MPI
// implementation must have a thread support value of MPI_THREAD_MULTIPLE;
// that is, multiple threads may call MPI with no restrictions.
//
// Also note that ProcessGroupMPI only supports a single Tensor operation. In
// other words, the size of the input Tensor vector should always be 1.
//
// CUDA tensors are supported if the MPI implementation used is CUDA-aware,
// and ProcessGroupMPI will automatically detect this support.
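//
// For reference, the thread-support requirement above corresponds to the
// standard MPI_Init_thread handshake. A minimal sketch (the library performs
// this internally, presumably in initMPIOnce(); do not call it yourself):
//
//   int provided = 0;
//   MPI_Init_thread(nullptr, nullptr, MPI_THREAD_SERIALIZED, &provided);
//   // provided must be at least MPI_THREAD_SERIALIZED for this backend,
//   // and MPI_THREAD_MULTIPLE to allow more than one process group.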
class TORCH_API ProcessGroupMPI : public Backend {
 public:
  class WorkMPI : public Work {
   public:
    explicit WorkMPI(
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt)
        : Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
          outputTensors_(std::move(outputTensors)),
          future_(c10::make_intrusive<at::ivalue::Future>(
              c10::ListType::create(c10::TensorType::get()))) {}

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;

   protected:
    friend class ProcessGroupMPI;

   private:
    void finishWorkMPI();
    void finishWorkMPIError(std::exception_ptr eptr);

    std::vector<at::Tensor> outputTensors_;
    c10::intrusive_ptr<at::ivalue::Future> future_;
  };
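
  // Example: completing a WorkMPI handle (an illustrative sketch; `pg` and
  // `tensors` are hypothetical, see createProcessGroupMPI() below):
  //
  //   auto work = pg->allreduce(tensors);
  //   work->wait();                  // blocking completion
  //   auto fut = work->getFuture();  // or wait via the future
  //   fut->wait();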

  class AsyncWork : public Work {
   public:
    AsyncWork(
        MPI_Request request,
        std::vector<at::Tensor> outputTensors,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt);

    virtual ~AsyncWork();

    bool isCompleted() override;

    bool isSuccess() const override;

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;

    void abort() override;

    std::vector<at::Tensor> result() override;

   protected:
    void populateException();

   private:
    const std::vector<at::Tensor> outputTensors_;
    MPI_Request request_;
    MPI_Status status_;
  };

  // The constructor spawns the worker thread loop.
  explicit ProcessGroupMPI(int rank, int size, MPI_Comm pgComm);

  virtual ~ProcessGroupMPI();

  // Abort the MPI program; needs to be called when an exception is detected.
  void abort();

  const std::string getBackendName() const override {
    return std::string(MPI_BACKEND_NAME);
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputbuffer,
      at::Tensor& inputbuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensor,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  // Create a new ProcessGroupMPI; this will initialize MPI if it is not
  // already initialized.
  static c10::intrusive_ptr<ProcessGroupMPI> createProcessGroupMPI(
      std::vector<int> ranks = {});
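
  // Example: typical end-to-end usage (an illustrative sketch; tensor setup
  // and error handling are omitted, and variable names are hypothetical):
  //
  //   // Launch with e.g. `mpirun -np 4 ./prog`; one group per process.
  //   auto pg = ProcessGroupMPI::createProcessGroupMPI();
  //   std::vector<at::Tensor> tensors = {at::ones({2, 2})};
  //   auto work = pg->allreduce(tensors);  // input vector size must be 1
  //   work->wait();  // tensors[0] now holds the element-wise sum over ranks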

 protected:
  using WorkType =
      std::tuple<std::unique_ptr<WorkEntry>, c10::intrusive_ptr<WorkMPI>>;
  // Worker thread loop
  void runLoop();
  // Helper function that is called by the destructor
  void destroy();

  c10::intrusive_ptr<Work> enqueue(
      std::unique_ptr<WorkEntry> entry,
      const char* profilingTitle = nullptr,
      const c10::optional<std::vector<at::Tensor>>& inputTensors =
          c10::nullopt);

  bool stop_;

  std::mutex pgMutex_;
  std::thread workerThread_;

  std::deque<WorkType> queue_;
  std::condition_variable queueProduceCV_;
  std::condition_variable queueConsumeCV_;

  // Global states
  static void initMPIOnce();
  static void mpiExit();
  static c10::once_flag onceFlagInitMPI;

  static std::mutex pgGlobalMutex_;
  static int mpiThreadSupport_;

  MPI_Comm pgComm_;
};

} // namespace c10d

#endif // USE_C10D_MPI