#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/library.h>

namespace c10d {
namespace {

TORCH_LIBRARY(c10d, m) {
  // The following ProcessGroup, Work, and ReduceOp definitions act as
  // declarations: they register these classes with TorchScript without
  // exposing their implementation details.
  m.class_<ProcessGroup>("ProcessGroup").def(torch::init<int64_t, int64_t>());
  m.class_<Work>("Work")
      .def(torch::init<>())
      .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
  m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
  m.def(
      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
  m.def(
      "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int dst, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int src, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int tag) -> __torch__.torch.classes.c10d.Work");
}
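
// Illustrative sketch (not part of the registration above; variable names are
// placeholders): once these schemas are registered, callers resolve and invoke
// them through the dispatcher. A typed handle for "c10d::send" looks roughly
// like this; the real call sites live elsewhere in torch/csrc/distributed/c10d.
//
//   static auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("c10d::send", "")
//       .typed<c10::intrusive_ptr<Work>(
//           at::TensorList,
//           const c10::intrusive_ptr<ProcessGroup>&,
//           int64_t /* dst */,
//           int64_t /* tag */)>();
//   auto work = op.call(tensors, process_group, dst_rank, tag);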
} // namespace

namespace ops {

// Below are ProcessGroup's corresponding ops for each backend. Ops are routed
// through the dispatcher to be dispatched to the appropriate backend (see the
// illustrative sketch after send_cuda below). Currently this routing is a
// no-op, as the process group does not have a list of backends.
c10::intrusive_ptr<Work> send_cpu(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t dstRank,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->send(tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
}

c10::intrusive_ptr<Work> send_cuda(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t dstRank,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->send(tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
}
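
// Illustrative sketch (a sketch only; cpu_bufs/cuda_bufs are placeholders):
// given the registrations at the bottom of this file, the dispatcher picks
// between send_cpu and send_cuda based on the device of the tensor arguments,
// so callers never select a backend explicitly.
//
//   std::vector<at::Tensor> cpu_bufs{at::ones({4})};
//   std::vector<at::Tensor> cuda_bufs{at::ones({4}, at::kCUDA)};
//   // Calling "c10d::send" with cpu_bufs dispatches to send_cpu, while
//   // cuda_bufs dispatches to send_cuda.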

c10::intrusive_ptr<Work> recv_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t srcRank,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->recv(tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
}

c10::intrusive_ptr<Work> recv_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t srcRank,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->recv(tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
}

c10::intrusive_ptr<Work> recv_any_source_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->recvAnysource(tensor_vec, static_cast<int>(tag));
}

c10::intrusive_ptr<Work> recv_any_source_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->recvAnysource(tensor_vec, static_cast<int>(tag));
}

c10::intrusive_ptr<Work> reduce_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->reduce(
          tensor_vec,
          ReduceOptions{
              *reduce_op.get(),
              root_rank,
              root_tensor,
              std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> reduce_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->reduce(
          tensor_vec,
          ReduceOptions{
              *reduce_op.get(),
              root_rank,
              root_tensor,
              std::chrono::milliseconds(timeout)});
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->broadcast(
              tensor_vec,
              BroadcastOptions{
                  root_rank, root_tensor, std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->broadcast(
              tensor_vec,
              BroadcastOptions{
                  root_rank, root_tensor, std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->allreduce(
              tensor_vec,
              AllreduceOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  // Return input tensors as output tensors to make inplace allreduce look like
  // a functional API, so that make_fx can correctly build the dependencies in
  // the graph later.
  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->allreduce(
              tensor_vec,
              AllreduceOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  // Return input tensors as output tensors to make inplace allreduce look like
  // a functional API, so that make_fx can correctly build the dependencies in
  // the graph later.
  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}
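
// Illustrative sketch of the functional-return convention above (assumption:
// tensors_in, process_group, reduce_op, and timeout are placeholder values,
// and the op is being traced, e.g. with make_fx): downstream nodes consume the
// echoed outputs instead of the mutated inputs, so the trace keeps a data
// dependency on the collective even though it runs in place.
//
//   auto [tensors_out, work] =
//       allreduce_cpu_(tensors_in, process_group, reduce_op, timeout);
//   work->wait();
//   // tensors_out[i] aliases tensors_in[i]; the reduction happened in place.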

c10::intrusive_ptr<Work> allreduce_coalesced_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
  opts.reduceOp = *reduce_op.get();
  opts.timeout = std::chrono::milliseconds(timeout);

  return process_group->getBackend(c10::DeviceType::CPU)
      ->allreduce_coalesced(tensor_vec, opts);
}

c10::intrusive_ptr<Work> allreduce_coalesced_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
  opts.reduceOp = *reduce_op.get();
  opts.timeout = std::chrono::milliseconds(timeout);

  return process_group->getBackend(c10::DeviceType::CUDA)
      ->allreduce_coalesced(tensor_vec, opts);
}

std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>
allgather_cpu_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    at::TensorList input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  auto input_tensors_vec = input_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->allgather(
              const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
              input_tensors_vec,
              AllgatherOptions{std::chrono::milliseconds(timeout)});

  // Copy output tensors (not storage) so that this can be used in a functional
  // manner
  return std::
      tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(
          output_tensors, work);
}

std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>
allgather_cuda_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    at::TensorList input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  auto input_tensors_vec = input_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->allgather(
              const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
              input_tensors_vec,
              AllgatherOptions{std::chrono::milliseconds(timeout)});

  // Copy output tensors (not storage) so that this can be used in a functional
  // manner
  return std::
      tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(
          output_tensors, work);
}
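
// Illustrative note on the "copy" mentioned above (a sketch; output_tensors,
// inputs, pg, and timeout are placeholders): returning output_tensors by value
// copies the vectors of at::Tensor handles, not the underlying storage, so the
// returned tensors still alias the buffers the backend writes into.
//
//   auto [outs, work] = allgather_cpu_(output_tensors, inputs, pg, timeout);
//   work->wait();
//   // outs[r][i] shares storage with output_tensors[r][i].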

std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_cpu_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  auto work = process_group->getBackend(c10::DeviceType::CPU)
                  ->_allgather_base(output_tensor, input_tensor);

  return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(output_tensor, work);
}

std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_cuda_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  auto work = process_group->getBackend(c10::DeviceType::CUDA)
                  ->_allgather_base(output_tensor, input_tensor);

  return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(output_tensor, work);
}

c10::intrusive_ptr<Work> allgather_coalesced_cpu_(
    const std::vector<std::vector<at::Tensor>>& output_lists,
    const at::TensorList& input_list,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  auto input_list_vec = input_list.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->allgather_coalesced(
          const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists),
          input_list_vec);
}

c10::intrusive_ptr<Work> allgather_coalesced_cuda_(
    const std::vector<std::vector<at::Tensor>>& output_lists,
    const at::TensorList& input_list,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  auto input_list_vec = input_list.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->allgather_coalesced(
          const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists),
          input_list_vec);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
reduce_scatter_cpu_(
    const at::TensorList& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->reduce_scatter(
              output_tensors_vec,
              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
              ReduceScatterOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      output_tensors_vec, work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
reduce_scatter_cuda_(
    const at::TensorList& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->reduce_scatter(
              output_tensors_vec,
              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
              ReduceScatterOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      output_tensors_vec, work);
}

std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _reduce_scatter_base_cpu_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->_reduce_scatter_base(
              output_tensor,
              input_tensor,
              ReduceScatterOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(output_tensor, work);
}

std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _reduce_scatter_base_cuda_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->_reduce_scatter_base(
              output_tensor,
              input_tensor,
              ReduceScatterOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(output_tensor, work);
}

c10::intrusive_ptr<Work> gather_cpu_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    const at::TensorList& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  auto input_tensors_vec = input_tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->gather(
          const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
          input_tensors_vec,
          GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> gather_cuda_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    const at::TensorList& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  auto input_tensors_vec = input_tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->gather(
          const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
          input_tensors_vec,
          GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cpu_(
    const at::TensorList& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->scatter(
              output_tensors_vec,
              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
              ScatterOptions{root_rank, std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(output_tensors_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cuda_(
    const at::TensorList& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->scatter(
              output_tensors_vec,
              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
              ScatterOptions{root_rank, std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(output_tensors_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_cpu_(
    const at::TensorList& output_tensors,
    const at::TensorList& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto input_tensors_vec = input_tensors.vec();
  auto work = process_group->getBackend(c10::DeviceType::CPU)
                  ->alltoall(
                      output_tensors_vec,
                      input_tensors_vec,
                      AllToAllOptions{std::chrono::milliseconds(timeout)});
  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(output_tensors_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_cuda_(
    const at::TensorList& output_tensors,
    const at::TensorList& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto input_tensors_vec = input_tensors.vec();
  auto work = process_group->getBackend(c10::DeviceType::CUDA)
                  ->alltoall(
                      output_tensors_vec,
                      input_tensors_vec,
                      AllToAllOptions{std::chrono::milliseconds(timeout)});
  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(output_tensors_vec), work);
}

c10::intrusive_ptr<Work> alltoall_base_cpu_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->getBackend(c10::DeviceType::CPU)
      ->alltoall_base(
          output,
          input,
          output_split_sizes,
          input_split_sizes,
          AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> alltoall_base_cuda_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->alltoall_base(
          output,
          input,
          output_split_sizes,
          input_split_sizes,
          AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> barrier_cpu(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout) {
  return process_group->getBackend(c10::DeviceType::CPU)
      ->barrier(
          BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> barrier_cuda(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout) {
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->barrier(
          BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
}

void monitored_barrier_cpu_(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout,
    bool wait_all_ranks) {
  process_group->getBackend(c10::DeviceType::CPU)
      ->monitoredBarrier(
          BarrierOptions{device_ids, std::chrono::milliseconds(timeout)},
          wait_all_ranks);
}

// register functions to dispatcher
namespace {
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("send", send_cpu);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("send", send_cuda);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("recv_", recv_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("recv_", recv_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("recv_any_source_", recv_any_source_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("recv_any_source_", recv_any_source_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("reduce_", reduce_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("reduce_", reduce_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("broadcast_", broadcast_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("broadcast_", broadcast_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("allreduce_", allreduce_cpu_);
}

// TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
// sparse all_reduce in the Gloo backend
TORCH_LIBRARY_IMPL(c10d, SparseCPU, m) {
  m.impl("allreduce_", allreduce_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, SparseCUDA, m) {
  m.impl("allreduce_", allreduce_cuda_);
}
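
// Illustrative sketch (assumption: a Gloo-backed process group): a sparse COO
// tensor carries the SparseCPU/SparseCUDA dispatch key, so an allreduce on it
// hits the registrations above rather than the dense CPU/CUDA ones.
//
//   auto sparse = at::ones({8}).to_sparse();
//   // Dispatching "c10d::allreduce_" with {sparse} routes to the SparseCPU
//   // registration here (allreduce_cpu_), which forwards to the backend's
//   // sparse allreduce support.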

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("allreduce_", allreduce_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("allreduce_coalesced_", allreduce_coalesced_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("allreduce_coalesced_", allreduce_coalesced_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("allgather_", allgather_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("allgather_", allgather_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("_allgather_base_", _allgather_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("_allgather_base_", _allgather_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("allgather_coalesced_", allgather_coalesced_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("allgather_coalesced_", allgather_coalesced_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("reduce_scatter_", reduce_scatter_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("reduce_scatter_", reduce_scatter_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("_reduce_scatter_base_", _reduce_scatter_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("_reduce_scatter_base_", _reduce_scatter_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("gather_", gather_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("gather_", gather_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("scatter_", scatter_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("scatter_", scatter_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("alltoall_", alltoall_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("alltoall_", alltoall_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("alltoall_base_", alltoall_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("alltoall_base_", alltoall_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("barrier", barrier_cpu);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("barrier", barrier_cuda);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("monitored_barrier_", monitored_barrier_cpu_);
}

} // namespace

} // namespace ops
} // namespace c10d