#include <ATen/core/functional.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/cuda/Stream.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/comm.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/ATen.h>

#include <cstddef>
#include <vector>

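// Python bindings for the CUDA comm primitives (broadcast / scatter / gather)
// declared in torch/csrc/cuda/comm.h.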
namespace torch {
namespace cuda {
namespace python {
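// Registers the _broadcast*/_scatter*/_gather* helpers on the given module.
// The blocking CUDA work runs with the GIL released, either via
// py::call_guard or an explicit gil_scoped_release inside the lambda.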
void initCommMethods(PyObject* module) {
  auto m = py::cast<py::module>(module);
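  // Broadcast a list of tensors to the given devices, coalescing the copies
  // into buffers bounded by `buffer_size`.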
  m.def(
       "_broadcast_coalesced",
       [](std::vector<at::Tensor>& tensors,
          std::vector<int64_t> devices,
          size_t buffer_size) {
         return broadcast_coalesced(tensors, devices, buffer_size);
       },
       py::arg("tensors"),
       py::arg("devices"),
       py::arg("buffer_size"),
       py::call_guard<py::gil_scoped_release>())
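      // Broadcast a single tensor to each of the given devices.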
      .def(
          "_broadcast",
          [](at::Tensor& tensor, std::vector<int64_t> devices) {
            return broadcast(tensor, devices);
          },
          py::call_guard<py::gil_scoped_release>(),
          py::arg("tensor"),
          py::arg("devices"))
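      // Broadcast a tensor into the preallocated `out` tensors.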
      .def(
          "_broadcast_out",
          [](at::Tensor& tensor, std::vector<at::Tensor>& out_tensors) {
            return broadcast_out(tensor, out_tensors);
          },
          py::call_guard<py::gil_scoped_release>(),
          py::arg("tensor"),
          py::arg("out"))
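      // Scatter chunks of `tensor` along `dim` across `devices`, optionally
      // with explicit `chunk_sizes` and per-device streams. The stream list is
      // converted while holding the GIL, which is then released by hand.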
      .def(
          "_scatter",
          [](at::Tensor& tensor,
             std::vector<int64_t>& devices,
             c10::optional<std::vector<int64_t>> chunk_sizes,
             int64_t dim,
             c10::optional<py::object> py_streams) {
            c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>
                streams;
            if (py_streams) {
              py::handle handle = *py_streams;
              streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
            }
            // Note: We're holding the GIL up to here.
            pybind11::gil_scoped_release no_gil;
            return scatter(tensor, devices, chunk_sizes, dim, streams);
          },
          py::arg("tensor"),
          py::arg("devices"),
          py::arg("chunk_sizes"),
          py::arg("dim"),
          py::arg("streams"))
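      // Same as _scatter, but writes into the preallocated `out` tensors.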
      .def(
          "_scatter_out",
          [](at::Tensor& tensor,
             std::vector<at::Tensor>& out_tensors,
             int64_t dim,
             c10::optional<py::object> py_streams) {
            c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>>
                streams;
            if (py_streams) {
              py::handle handle = *py_streams;
              streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr());
            }
            // Note: We're holding the GIL up to here.
            pybind11::gil_scoped_release no_gil;
            return scatter_out(tensor, out_tensors, dim, streams);
          },
          py::arg("tensor"),
          py::arg("out"),
          py::arg("dim"),
          py::arg("streams"))
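      // Gather tensors from multiple devices, concatenating them along `dim`
      // on the device selected by `destination_index`.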
      .def(
          "_gather",
          [](std::vector<at::Tensor>& tensors,
             int64_t dim,
             c10::optional<int32_t> destination_index) {
            return gather(tensors, dim, destination_index);
          },
          py::arg("tensors"),
          py::arg("dim"),
          py::arg("destination_index"),
          py::call_guard<py::gil_scoped_release>())
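      // Gather tensors into a preallocated `out` tensor along `dim`.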
      .def(
          "_gather_out",
          [](std::vector<at::Tensor>& tensors,
             at::Tensor& out_tensor,
             int64_t dim) { return gather_out(tensors, out_tensor, dim); },
          py::arg("tensors"),
          py::arg("out"),
          py::arg("dim"),
          py::call_guard<py::gil_scoped_release>());
}
} // namespace python
} // namespace cuda
} // namespace torch