1 | #include <torch/csrc/cuda/python_nccl.h> |
2 | |
3 | #include <ATen/core/functional.h> |
4 | #include <pybind11/pybind11.h> |
5 | #include <torch/csrc/DynamicTypes.h> |
6 | #include <torch/csrc/Exceptions.h> |
7 | #include <torch/csrc/THP.h> |
8 | #include <torch/csrc/Types.h> |
9 | #include <torch/csrc/cuda/THCP.h> |
10 | #include <torch/csrc/cuda/nccl.h> |
11 | #include <torch/csrc/utils/pybind.h> |
12 | |
13 | #include <c10/cuda/CUDAGuard.h> |
14 | #include <c10/util/irange.h> |
15 | |
16 | #include <sstream> |
17 | #include <unordered_map> |
18 | |
19 | using namespace at; |
20 | using namespace torch; |
21 | using namespace torch::cuda::nccl; |
22 | using namespace torch::cuda::nccl::detail; |
23 | |
// Name embedded in communicator PyCapsules; PyCapsule_GetPointer must be
// called with this exact string to unwrap an ncclComm_t (see unpack_nccl_comm).
static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator" ;
25 | |
26 | PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) { |
27 | return PyInt_FromLong(version()); |
28 | } |
29 | |
30 | PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) { |
31 | HANDLE_TH_ERRORS |
32 | ncclUniqueId id; |
33 | get_unique_id(id); |
34 | return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES); |
35 | END_HANDLE_TH_ERRORS |
36 | } |
37 | |
38 | static ncclComm_t unpack_nccl_comm(PyObject* capsule) { |
39 | ncclComm_t comm = |
40 | (ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME); |
41 | if (!comm) |
42 | throw python_error(); |
43 | return comm; |
44 | } |
45 | |
46 | static void destroy_nccl_comm(PyObject* capsule) { |
47 | HANDLE_TH_ERRORS |
48 | ncclComm_t comm = unpack_nccl_comm(capsule); |
49 | { |
50 | pybind11::gil_scoped_release no_gil; |
51 | comm_destroy(comm); |
52 | } |
53 | END_HANDLE_TH_ERRORS_RET() |
54 | } |
55 | |
56 | static std::vector<c10::optional<at::cuda::CUDAStream>> unpack_streams( |
57 | PyObject* obj, |
58 | size_t size) { |
59 | if (obj == Py_None) { |
60 | return std::vector<c10::optional<at::cuda::CUDAStream>>(size, c10::nullopt); |
61 | } |
62 | auto streams = THPUtils_PySequence_to_CUDAStreamList(obj); |
63 | if (streams.size() != size) { |
64 | throw std::runtime_error( |
65 | "number of streams is not equal to number of inputs" ); |
66 | } |
67 | return streams; |
68 | } |
69 | |
// Forward declarations; the definitions live at the bottom of this file.
static inline at::Tensor extract_tensor(PyObject* obj);
static inline std::vector<at::Tensor> extract_tensors(PyObject* obj);
72 | |
73 | static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) { |
74 | if (obj == Py_None) { |
75 | return std::vector<ncclComm_t>(); |
76 | } |
77 | std::vector<ncclComm_t> comms; |
78 | if (PyCapsule_CheckExact(obj)) { |
79 | comms = {unpack_nccl_comm(obj)}; |
80 | } else { |
81 | auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence" )); |
82 | if (!seq) |
83 | throw python_error(); |
84 | auto size = PySequence_Fast_GET_SIZE(seq.get()); |
85 | comms = std::vector<ncclComm_t>(size); |
86 | for (const auto i : c10::irange(size)) { |
87 | comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i)); |
88 | } |
89 | } |
90 | if (comms.size() != size) { |
91 | throw std::runtime_error( |
92 | "number of communicators is not equal to number of inputs" ); |
93 | } |
94 | return comms; |
95 | } |
96 | |
97 | PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) { |
98 | HANDLE_TH_ERRORS |
99 | int nranks; |
100 | const char* id; |
101 | Py_ssize_t id_len; |
102 | int rank; |
103 | |
104 | if (!PyArg_ParseTuple( |
105 | args, "is#i:nccl_init_rank" , &nranks, &id, &id_len, &rank)) { |
106 | return nullptr; |
107 | } |
108 | THPUtils_assert( |
109 | id_len == NCCL_UNIQUE_ID_BYTES, |
110 | "invalid unqiue_id (expected %d bytes, got %zd)" , |
111 | NCCL_UNIQUE_ID_BYTES, |
112 | id_len); |
113 | |
114 | ncclUniqueId commId; |
115 | memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES); |
116 | ncclComm_t comm; |
117 | { |
118 | pybind11::gil_scoped_release no_gil; |
119 | comm = comm_init_rank(nranks, commId, rank); |
120 | } |
121 | return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm); |
122 | END_HANDLE_TH_ERRORS |
123 | } |
124 | |
125 | PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) { |
126 | HANDLE_TH_ERRORS |
127 | PyObject *_inputs, *_output, *_streams, *_comms; |
128 | int root, op; |
129 | |
130 | if (!PyArg_ParseTuple( |
131 | args, "OOiiOO" , &_inputs, &_output, &root, &op, &_streams, &_comms)) { |
132 | THPUtils_invalidArguments( |
133 | args, |
134 | nullptr, |
135 | "nccl_reduce" , |
136 | 1, |
137 | "(sequence[Tensor] inputs, Tensor output, int root," |
138 | " int op, sequence[torch.cuda.Stream or None]" ); |
139 | return nullptr; |
140 | } |
141 | |
142 | std::vector<at::Tensor> inputs = extract_tensors(_inputs); |
143 | auto output = extract_tensor(_output); |
144 | std::vector<c10::optional<at::cuda::CUDAStream>> streams = |
145 | unpack_streams(_streams, inputs.size()); |
146 | auto user_comms = unpack_comms(_comms, inputs.size()); |
147 | |
148 | { |
149 | pybind11::gil_scoped_release no_gil; |
150 | torch::cuda::nccl::reduce(inputs, output, root, op, streams, user_comms); |
151 | } |
152 | |
153 | Py_RETURN_NONE; |
154 | END_HANDLE_TH_ERRORS |
155 | } |
156 | |
157 | PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) { |
158 | HANDLE_TH_ERRORS |
159 | PyObject *_inputs, *_outputs, *_streams, *_comms; |
160 | int op; |
161 | |
162 | if (!PyArg_ParseTuple( |
163 | args, "OOiOO" , &_inputs, &_outputs, &op, &_streams, &_comms)) { |
164 | THPUtils_invalidArguments( |
165 | args, |
166 | nullptr, |
167 | "nccl_all_reduce" , |
168 | 1, |
169 | "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op," |
170 | " sequence[torch.cuda.Stream] streams," |
171 | " sequence[torch.cuda.nccl.Communicator] comms)" ); |
172 | return nullptr; |
173 | } |
174 | |
175 | std::vector<at::Tensor> inputs = extract_tensors(_inputs); |
176 | std::vector<at::Tensor> outputs = extract_tensors(_outputs); |
177 | auto streams = unpack_streams(_streams, inputs.size()); |
178 | auto user_comms = unpack_comms(_comms, inputs.size()); |
179 | |
180 | { |
181 | pybind11::gil_scoped_release no_gil; |
182 | all_reduce(inputs, outputs, op, streams, user_comms); |
183 | } |
184 | |
185 | Py_RETURN_NONE; |
186 | END_HANDLE_TH_ERRORS |
187 | } |
188 | |
189 | PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) { |
190 | HANDLE_TH_ERRORS |
191 | PyObject *_inputs, *_streams, *_comms; |
192 | int root; |
193 | |
194 | if (!PyArg_ParseTuple(args, "OiOO" , &_inputs, &root, &_streams, &_comms)) { |
195 | THPUtils_invalidArguments( |
196 | args, |
197 | nullptr, |
198 | "nccl_broadcast" , |
199 | 1, |
200 | "(sequence[Tensor] inputs, int root" |
201 | " sequence[torch.cuda.Stream] streams," |
202 | " sequence[torch.cuda.nccl.Communicator] comms)" ); |
203 | return nullptr; |
204 | } |
205 | |
206 | std::vector<at::Tensor> inputs = extract_tensors(_inputs); |
207 | THPUtils_assert(root >= 0 && (size_t)root < inputs.size(), "invalid root" ); |
208 | auto streams = unpack_streams(_streams, inputs.size()); |
209 | auto user_comms = unpack_comms(_comms, inputs.size()); |
210 | |
211 | { |
212 | pybind11::gil_scoped_release no_gil; |
213 | torch::cuda::nccl::broadcast(inputs, streams, user_comms); |
214 | } |
215 | |
216 | Py_RETURN_NONE; |
217 | END_HANDLE_TH_ERRORS |
218 | } |
219 | |
220 | PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) { |
221 | HANDLE_TH_ERRORS |
222 | PyObject *_inputs, *_outputs, *_streams, *_comms; |
223 | |
224 | if (!PyArg_ParseTuple( |
225 | args, "OOOO" , &_inputs, &_outputs, &_streams, &_comms)) { |
226 | THPUtils_invalidArguments( |
227 | args, |
228 | nullptr, |
229 | "nccl_all_gather" , |
230 | 1, |
231 | "(sequence[Tensor] inputs, sequence[Tensor] outputs" |
232 | " sequence[torch.cuda.Stream] streams," |
233 | " sequence[torch.cuda.nccl.Communicator] comms)" ); |
234 | return nullptr; |
235 | } |
236 | |
237 | std::vector<at::Tensor> inputs = extract_tensors(_inputs); |
238 | std::vector<at::Tensor> outputs = extract_tensors(_outputs); |
239 | auto streams = unpack_streams(_streams, inputs.size()); |
240 | auto user_comms = unpack_comms(_comms, inputs.size()); |
241 | |
242 | { |
243 | pybind11::gil_scoped_release no_gil; |
244 | all_gather(inputs, outputs, streams, user_comms); |
245 | } |
246 | |
247 | Py_RETURN_NONE; |
248 | END_HANDLE_TH_ERRORS |
249 | } |
250 | |
251 | PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) { |
252 | HANDLE_TH_ERRORS |
253 | PyObject *_inputs, *_outputs, *_streams, *_comms; |
254 | int op; |
255 | |
256 | if (!PyArg_ParseTuple( |
257 | args, "OOiOO" , &_inputs, &_outputs, &op, &_streams, &_comms)) { |
258 | THPUtils_invalidArguments( |
259 | args, |
260 | nullptr, |
261 | "nccl_reduce_scatter" , |
262 | 1, |
263 | "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op" |
264 | " sequence[torch.cuda.Stream] streams," |
265 | " sequence[torch.cuda.nccl.Communicator] comms)" ); |
266 | return nullptr; |
267 | } |
268 | |
269 | std::vector<at::Tensor> inputs = extract_tensors(_inputs); |
270 | std::vector<at::Tensor> outputs = extract_tensors(_outputs); |
271 | auto streams = unpack_streams(_streams, inputs.size()); |
272 | auto user_comms = unpack_comms(_comms, inputs.size()); |
273 | |
274 | { |
275 | pybind11::gil_scoped_release no_gil; |
276 | reduce_scatter(inputs, outputs, op, streams, user_comms); |
277 | } |
278 | |
279 | Py_RETURN_NONE; |
280 | END_HANDLE_TH_ERRORS |
281 | } |
282 | |
283 | static inline at::Tensor (PyObject* obj) { |
284 | if (!THPVariable_Check(obj)) { |
285 | throw torch::TypeError("expected Tensor (got %s)" , Py_TYPE(obj)->tp_name); |
286 | } |
287 | return THPVariable_Unpack(obj); |
288 | } |
289 | |
290 | static inline std::vector<at::Tensor> (PyObject* obj) { |
291 | auto seq = THPObjectPtr(PySequence_Fast(obj, "expected a sequence" )); |
292 | if (!seq) |
293 | throw python_error(); |
294 | |
295 | const Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get()); |
296 | std::vector<at::Tensor> list; |
297 | if (length >= 0) { |
298 | list.reserve(length); |
299 | } |
300 | for (Py_ssize_t i = 0; i < length; i++) { |
301 | PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i); |
302 | if (!THPVariable_Check(item)) { |
303 | throw torch::TypeError( |
304 | "expected Tensor at %d (got %s)" , (int)i, Py_TYPE(item)->tp_name); |
305 | } |
306 | list.emplace_back(THPVariable_Unpack(item)); |
307 | } |
308 | return list; |
309 | } |
310 | |