1 | #include <ATen/core/functional.h> |
2 | #include <pybind11/pybind11.h> |
3 | #include <torch/csrc/cuda/Stream.h> |
4 | #include <torch/csrc/cuda/THCP.h> |
5 | #include <torch/csrc/cuda/comm.h> |
6 | #include <torch/csrc/utils/pybind.h> |
7 | |
8 | #include <ATen/ATen.h> |
9 | |
10 | #include <cstddef> |
11 | #include <vector> |
12 | |
13 | namespace torch { |
14 | namespace cuda { |
15 | namespace python { |
16 | void initCommMethods(PyObject* module) { |
17 | auto m = py::cast<py::module>(module); |
18 | m.def( |
19 | "_broadcast_coalesced" , |
20 | [](std::vector<at::Tensor>& tensors, |
21 | std::vector<int64_t> devices, |
22 | size_t buffer_size) { |
23 | return broadcast_coalesced(tensors, devices, buffer_size); |
24 | }, |
25 | py::arg("tensors" ), |
26 | py::arg("devices" ), |
27 | py::arg("buffer_size" ), |
28 | py::call_guard<py::gil_scoped_release>()) |
29 | .def( |
30 | "_broadcast" , |
31 | [](at::Tensor& tensor, std::vector<int64_t> devices) { |
32 | return broadcast(tensor, devices); |
33 | }, |
34 | py::call_guard<py::gil_scoped_release>(), |
35 | py::arg("tensor" ), |
36 | py::arg("devices" )) |
37 | .def( |
38 | "_broadcast_out" , |
39 | [](at::Tensor& tensor, std::vector<at::Tensor>& out_tensors) { |
40 | return broadcast_out(tensor, out_tensors); |
41 | }, |
42 | py::call_guard<py::gil_scoped_release>(), |
43 | py::arg("tensor" ), |
44 | py::arg("out" )) |
45 | .def( |
46 | "_scatter" , |
47 | [](at::Tensor& tensor, |
48 | std::vector<int64_t>& devices, |
49 | c10::optional<std::vector<int64_t>> chunk_sizes, |
50 | int64_t dim, |
51 | c10::optional<py::object> py_streams) { |
52 | c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>> |
53 | streams; |
54 | if (py_streams) { |
55 | py::handle handle = *py_streams; |
56 | streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr()); |
57 | } |
58 | // Note: We're holding the GIL up to here. |
59 | pybind11::gil_scoped_release no_gil; |
60 | return scatter(tensor, devices, chunk_sizes, dim, streams); |
61 | }, |
62 | py::arg("tensor" ), |
63 | py::arg("devices" ), |
64 | py::arg("chunk_sizes" ), |
65 | py::arg("dim" ), |
66 | py::arg("streams" )) |
67 | .def( |
68 | "_scatter_out" , |
69 | [](at::Tensor& tensor, |
70 | std::vector<at::Tensor>& out_tensors, |
71 | int64_t dim, |
72 | c10::optional<py::object> py_streams) { |
73 | c10::optional<std::vector<c10::optional<at::cuda::CUDAStream>>> |
74 | streams; |
75 | if (py_streams) { |
76 | py::handle handle = *py_streams; |
77 | streams = THPUtils_PySequence_to_CUDAStreamList(handle.ptr()); |
78 | } |
79 | // Note: We're holding the GIL up to here. |
80 | pybind11::gil_scoped_release no_gil; |
81 | return scatter_out(tensor, out_tensors, dim, streams); |
82 | }, |
83 | py::arg("tensor" ), |
84 | py::arg("out" ), |
85 | py::arg("dim" ), |
86 | py::arg("streams" )) |
87 | .def( |
88 | "_gather" , |
89 | [](std::vector<at::Tensor>& tensors, |
90 | int64_t dim, |
91 | c10::optional<int32_t> destination_index) { |
92 | return gather(tensors, dim, destination_index); |
93 | }, |
94 | py::arg("tensors" ), |
95 | py::arg("dim" ), |
96 | py::arg("destination_index" ), |
97 | py::call_guard<py::gil_scoped_release>()) |
98 | .def( |
99 | "_gather_out" , |
100 | [](std::vector<at::Tensor>& tensors, |
101 | at::Tensor& out_tensor, |
102 | int64_t dim) { return gather_out(tensors, out_tensor, dim); }, |
103 | py::arg("tensors" ), |
104 | py::arg("out" ), |
105 | py::arg("dim" ), |
106 | py::call_guard<py::gil_scoped_release>()); |
107 | } |
108 | } // namespace python |
109 | } // namespace cuda |
110 | } // namespace torch |
111 | |