#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/library.h>

namespace c10d {
namespace {

TORCH_LIBRARY(c10d, m) {
  // The following ProcessGroup, Work, and ReduceOp definitions are more like
  // declarations. They don't expose the details of these classes into
  // TorchScript.
  m.class_<ProcessGroup>("ProcessGroup").def(torch::init<int64_t, int64_t>());
  m.class_<Work>("Work")
      .def(torch::init<>())
      .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
  m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
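  // The operator schemas below refer to the classes registered above by their
  // fully qualified TorchScript names: for example, the custom class
  // "ProcessGroup" registered in the "c10d" namespace appears in a schema
  // string as __torch__.torch.classes.c10d.ProcessGroup.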
  m.def(
      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
  m.def(
      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
  m.def(
      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
  m.def(
      "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int dst, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int src, int tag) -> __torch__.torch.classes.c10d.Work");
  m.def(
      "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int tag) -> __torch__.torch.classes.c10d.Work");
}
} // namespace

namespace ops {

// Below are ProcessGroup's corresponding ops for each backend. Ops are
// routed through the dispatcher to be dispatched to the appropriate backend
// implementation, which each op looks up via ProcessGroup::getBackend().
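//
// For illustration only (an assumed call-site sketch, not code that lives in
// this file): once an op such as c10d::allreduce_ has been registered above,
// a caller reaches the backend-specific implementation through the dispatcher
// rather than calling the functions below directly. The names mirror the
// schema registered in TORCH_LIBRARY(c10d, m).
//
//   static auto op =
//       c10::Dispatcher::singleton()
//           .findSchemaOrThrow("c10d::allreduce_", "")
//           .typed<std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
//               at::TensorList,
//               const c10::intrusive_ptr<ProcessGroup>&,
//               const c10::intrusive_ptr<ReduceOp>&,
//               int64_t)>();
//   auto [tensors_out, work] =
//       op.call(tensors, process_group, reduce_op, timeout);
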
c10::intrusive_ptr<Work> send_cpu(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t dstRank,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->send(tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
}

c10::intrusive_ptr<Work> send_cuda(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t dstRank,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->send(tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
}

c10::intrusive_ptr<Work> recv_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t srcRank,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->recv(tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
}

c10::intrusive_ptr<Work> recv_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t srcRank,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->recv(tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
}

c10::intrusive_ptr<Work> recv_any_source_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->recvAnysource(tensor_vec, static_cast<int>(tag));
}

c10::intrusive_ptr<Work> recv_any_source_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t tag) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->recvAnysource(tensor_vec, static_cast<int>(tag));
}

c10::intrusive_ptr<Work> reduce_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->reduce(
          tensor_vec,
          ReduceOptions{
              *reduce_op.get(),
              root_rank,
              root_tensor,
              std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> reduce_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->reduce(
          tensor_vec,
          ReduceOptions{
              *reduce_op.get(),
              root_rank,
              root_tensor,
              std::chrono::milliseconds(timeout)});
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->broadcast(
              tensor_vec,
              BroadcastOptions{
                  root_rank, root_tensor, std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t root_tensor,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->broadcast(
              tensor_vec,
              BroadcastOptions{
                  root_rank, root_tensor, std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->allreduce(
              tensor_vec,
              AllreduceOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  // Return input tensors as output tensors to make inplace allreduce look like
  // a functional API, so that make_fx can correctly build the dependencies in
  // the graph later.
  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}
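
// Illustrative sketch only (a hypothetical call site, not part of this file):
// the returned vector aliases the input tensors, so the in-place collective
// still presents a functional data flow to tracers such as make_fx.
//
//   auto [outputs, work] = allreduce_cpu_(tensors, process_group, op, timeout);
//   work->wait();
//   // outputs[i] shares storage with tensors[i]; consumers should read the
//   // returned outputs so the dependency on the collective stays visible.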

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->allreduce(
              tensor_vec,
              AllreduceOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  // Return input tensors as output tensors to make inplace allreduce look like
  // a functional API, so that make_fx can correctly build the dependencies in
  // the graph later.
  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(tensor_vec), work);
}

c10::intrusive_ptr<Work> allreduce_coalesced_cpu_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
  opts.reduceOp = *reduce_op.get();
  opts.timeout = std::chrono::milliseconds(timeout);

  return process_group->getBackend(c10::DeviceType::CPU)
      ->allreduce_coalesced(tensor_vec, opts);
}

c10::intrusive_ptr<Work> allreduce_coalesced_cuda_(
    at::TensorList tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto tensor_vec = tensors.vec();
  AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
  opts.reduceOp = *reduce_op.get();
  opts.timeout = std::chrono::milliseconds(timeout);

  return process_group->getBackend(c10::DeviceType::CUDA)
      ->allreduce_coalesced(tensor_vec, opts);
}

std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>
allgather_cpu_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    at::TensorList input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  auto input_tensors_vec = input_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->allgather(
              const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
              input_tensors_vec,
              AllgatherOptions{std::chrono::milliseconds(timeout)});

  // Copy output tensors (not storage) so that this can be used in a functional
  // manner
  return std::
      tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(
          output_tensors, work);
}

std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>
allgather_cuda_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    at::TensorList input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  auto input_tensors_vec = input_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->allgather(
              const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
              input_tensors_vec,
              AllgatherOptions{std::chrono::milliseconds(timeout)});

  // Copy output tensors (not storage) so that this can be used in a functional
  // manner
  return std::
      tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>(
          output_tensors, work);
}

std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_cpu_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  auto work = process_group->getBackend(c10::DeviceType::CPU)
                  ->_allgather_base(output_tensor, input_tensor);

  return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(output_tensor, work);
}

std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _allgather_base_cuda_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  auto work = process_group->getBackend(c10::DeviceType::CUDA)
                  ->_allgather_base(output_tensor, input_tensor);

  return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(output_tensor, work);
}

c10::intrusive_ptr<Work> allgather_coalesced_cpu_(
    const std::vector<std::vector<at::Tensor>>& output_lists,
    const at::TensorList& input_list,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  auto input_list_vec = input_list.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->allgather_coalesced(
          const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists),
          input_list_vec);
}

c10::intrusive_ptr<Work> allgather_coalesced_cuda_(
    const std::vector<std::vector<at::Tensor>>& output_lists,
    const at::TensorList& input_list,
    const c10::intrusive_ptr<ProcessGroup>& process_group) {
  auto input_list_vec = input_list.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->allgather_coalesced(
          const_cast<std::vector<std::vector<at::Tensor>>&>(output_lists),
          input_list_vec);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
reduce_scatter_cpu_(
    const at::TensorList& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->reduce_scatter(
              output_tensors_vec,
              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
              ReduceScatterOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      output_tensors_vec, work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
reduce_scatter_cuda_(
    const at::TensorList& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->reduce_scatter(
              output_tensors_vec,
              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
              ReduceScatterOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      output_tensors_vec, work);
}

std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _reduce_scatter_base_cpu_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->_reduce_scatter_base(
              output_tensor,
              input_tensor,
              ReduceScatterOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(output_tensor, work);
}

std::tuple<at::Tensor, c10::intrusive_ptr<Work>> _reduce_scatter_base_cuda_(
    at::Tensor& output_tensor,
    at::Tensor& input_tensor,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const c10::intrusive_ptr<ReduceOp>& reduce_op,
    int64_t timeout) {
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->_reduce_scatter_base(
              output_tensor,
              input_tensor,
              ReduceScatterOptions{
                  *reduce_op.get(), std::chrono::milliseconds(timeout)});

  return std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(output_tensor, work);
}

c10::intrusive_ptr<Work> gather_cpu_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    const at::TensorList& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  auto input_tensors_vec = input_tensors.vec();
  return process_group->getBackend(c10::DeviceType::CPU)
      ->gather(
          const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
          input_tensors_vec,
          GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> gather_cuda_(
    const std::vector<std::vector<at::Tensor>>& output_tensors,
    const at::TensorList& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  auto input_tensors_vec = input_tensors.vec();
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->gather(
          const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
          input_tensors_vec,
          GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cpu_(
    const at::TensorList& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CPU)
          ->scatter(
              output_tensors_vec,
              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
              ScatterOptions{root_rank, std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(output_tensors_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cuda_(
    const at::TensorList& output_tensors,
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t root_rank,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto work =
      process_group->getBackend(c10::DeviceType::CUDA)
          ->scatter(
              output_tensors_vec,
              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
              ScatterOptions{root_rank, std::chrono::milliseconds(timeout)});

  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(output_tensors_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_cpu_(
    const at::TensorList& output_tensors,
    const at::TensorList& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto input_tensors_vec = input_tensors.vec();
  auto work = process_group->getBackend(c10::DeviceType::CPU)
                  ->alltoall(
                      output_tensors_vec,
                      input_tensors_vec,
                      AllToAllOptions{std::chrono::milliseconds(timeout)});
  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(output_tensors_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> alltoall_cuda_(
    const at::TensorList& output_tensors,
    const at::TensorList& input_tensors,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    int64_t timeout) {
  auto output_tensors_vec = output_tensors.vec();
  auto input_tensors_vec = input_tensors.vec();
  auto work = process_group->getBackend(c10::DeviceType::CUDA)
                  ->alltoall(
                      output_tensors_vec,
                      input_tensors_vec,
                      AllToAllOptions{std::chrono::milliseconds(timeout)});
  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
      std::move(output_tensors_vec), work);
}

c10::intrusive_ptr<Work> alltoall_base_cpu_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->getBackend(c10::DeviceType::CPU)
      ->alltoall_base(
          output,
          input,
          output_split_sizes,
          input_split_sizes,
          AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> alltoall_base_cuda_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->alltoall_base(
          output,
          input,
          output_split_sizes,
          input_split_sizes,
          AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> barrier_cpu(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout) {
  return process_group->getBackend(c10::DeviceType::CPU)
      ->barrier(BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> barrier_cuda(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout) {
  return process_group->getBackend(c10::DeviceType::CUDA)
      ->barrier(BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
}

void monitored_barrier_cpu_(
    at::Tensor /* unused */,
    const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
    int64_t timeout,
    bool wait_all_ranks) {
  process_group->getBackend(c10::DeviceType::CPU)
      ->monitoredBarrier(
          BarrierOptions{device_ids, std::chrono::milliseconds(timeout)},
          wait_all_ranks);
}

// register functions to dispatcher
namespace {
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("send", send_cpu);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("send", send_cuda);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("recv_", recv_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("recv_", recv_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("recv_any_source_", recv_any_source_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("recv_any_source_", recv_any_source_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("reduce_", reduce_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("reduce_", reduce_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("broadcast_", broadcast_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("broadcast_", broadcast_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("allreduce_", allreduce_cpu_);
}

// TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
// sparse all_reduce in the Gloo backend
TORCH_LIBRARY_IMPL(c10d, SparseCPU, m) {
  m.impl("allreduce_", allreduce_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, SparseCUDA, m) {
  m.impl("allreduce_", allreduce_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("allreduce_", allreduce_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("allreduce_coalesced_", allreduce_coalesced_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("allreduce_coalesced_", allreduce_coalesced_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("allgather_", allgather_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("allgather_", allgather_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("_allgather_base_", _allgather_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("_allgather_base_", _allgather_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("allgather_coalesced_", allgather_coalesced_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("allgather_coalesced_", allgather_coalesced_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("reduce_scatter_", reduce_scatter_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("reduce_scatter_", reduce_scatter_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("_reduce_scatter_base_", _reduce_scatter_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("_reduce_scatter_base_", _reduce_scatter_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("gather_", gather_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("gather_", gather_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("scatter_", scatter_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("scatter_", scatter_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("alltoall_", alltoall_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("alltoall_", alltoall_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("alltoall_base_", alltoall_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("alltoall_base_", alltoall_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("barrier", barrier_cpu);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("barrier", barrier_cuda);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("monitored_barrier_", monitored_barrier_cpu_);
}
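
// Illustrative sketch only (assumed, not part of this file): a backend for a
// different device could register its own implementations for these same
// schemas by targeting another dispatch key, e.g. PrivateUse1. The function
// name allreduce_privateuse1_ below is hypothetical.
//
//   TORCH_LIBRARY_IMPL(c10d, PrivateUse1, m) {
//     m.impl("allreduce_", allreduce_privateuse1_);
//   }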

} // namespace

} // namespace ops
} // namespace c10d