#include <torch/csrc/cuda/python_nccl.h>

#include <ATen/core/functional.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/Types.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/nccl.h>
#include <torch/csrc/utils/pybind.h>

#include <c10/cuda/CUDAGuard.h>
#include <c10/util/irange.h>

#include <sstream>
#include <unordered_map>

using namespace at;
using namespace torch;
using namespace torch::cuda::nccl;
using namespace torch::cuda::nccl::detail;

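// Name under which ncclComm_t pointers are wrapped in PyCapsule objects;
// unpack_nccl_comm() below only accepts capsules carrying this tag.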
static const char* COMM_CAPSULE_NAME = "torch.cuda.nccl.Communicator";

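// Returns the NCCL version as a Python int (the encoded integer reported by
// the underlying NCCL library, e.g. via ncclGetVersion).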
PyObject* THCPModule_nccl_version(PyObject* self, PyObject* args) {
  return PyLong_FromLong(version());
}

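// Returns the raw bytes of a freshly generated ncclUniqueId. Typically rank 0
// creates the id and distributes it out-of-band to the other ranks, which
// then pass it to nccl_init_rank() below.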
PyObject* THCPModule_nccl_unique_id(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  ncclUniqueId id;
  get_unique_id(id);
  return PyBytes_FromStringAndSize((char*)&id, NCCL_UNIQUE_ID_BYTES);
  END_HANDLE_TH_ERRORS
}

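// Extracts the ncclComm_t from a capsule, raising a Python error if the
// capsule does not carry the expected COMM_CAPSULE_NAME tag.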
static ncclComm_t unpack_nccl_comm(PyObject* capsule) {
  ncclComm_t comm =
      (ncclComm_t)PyCapsule_GetPointer(capsule, COMM_CAPSULE_NAME);
  if (!comm)
    throw python_error();
  return comm;
}

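// Capsule destructor: tears down the communicator when the capsule is
// garbage-collected. The GIL is released since destroying a communicator
// may block.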
static void destroy_nccl_comm(PyObject* capsule) {
  HANDLE_TH_ERRORS
  ncclComm_t comm = unpack_nccl_comm(capsule);
  {
    pybind11::gil_scoped_release no_gil;
    comm_destroy(comm);
  }
  END_HANDLE_TH_ERRORS_RET()
}

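// Converts an optional sequence of torch.cuda.Stream objects into per-device
// stream handles. Py_None yields one nullopt per input, i.e. no explicit
// stream override; otherwise the stream count must match the input count.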
static std::vector<c10::optional<at::cuda::CUDAStream>> unpack_streams(
    PyObject* obj,
    size_t size) {
  if (obj == Py_None) {
    return std::vector<c10::optional<at::cuda::CUDAStream>>(size, c10::nullopt);
  }
  auto streams = THPUtils_PySequence_to_CUDAStreamList(obj);
  if (streams.size() != size) {
    throw std::runtime_error(
        "number of streams is not equal to number of inputs");
  }
  return streams;
}

static inline at::Tensor extract_tensor(PyObject* obj);
static inline std::vector<at::Tensor> extract_tensors(PyObject* obj);

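// Accepts None, a single capsule, or a sequence of capsules. None yields an
// empty vector, deferring communicator selection to the underlying nccl
// bindings; otherwise the count must match the number of input tensors.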
static std::vector<ncclComm_t> unpack_comms(PyObject* obj, size_t size) {
  if (obj == Py_None) {
    return std::vector<ncclComm_t>();
  }
  std::vector<ncclComm_t> comms;
  if (PyCapsule_CheckExact(obj)) {
    comms = {unpack_nccl_comm(obj)};
  } else {
    auto seq = THPObjectPtr(PySequence_Fast(obj, "comm is not a sequence"));
    if (!seq)
      throw python_error();
    auto seq_size = PySequence_Fast_GET_SIZE(seq.get());
    comms = std::vector<ncclComm_t>(seq_size);
    for (const auto i : c10::irange(seq_size)) {
      comms[i] = unpack_nccl_comm(PySequence_Fast_GET_ITEM(seq.get(), i));
    }
  }
  if (comms.size() != size) {
    throw std::runtime_error(
        "number of communicators is not equal to number of inputs");
  }
  return comms;
}

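// Creates a communicator for this rank and wraps it in a capsule that owns
// it. Python-side sketch (hypothetical wrapper names, assuming the
// torch.cuda.nccl module exposes these bindings):
//
//   uid = torch.cuda.nccl.unique_id()        # on rank 0, then share it
//   comm = torch.cuda.nccl.init_rank(nranks, uid, rank)  # on every rank
//
// The unique id must be exactly NCCL_UNIQUE_ID_BYTES long.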
PyObject* THCPModule_nccl_init_rank(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  int nranks;
  const char* id;
  Py_ssize_t id_len;
  int rank;

  if (!PyArg_ParseTuple(
          args, "is#i:nccl_init_rank", &nranks, &id, &id_len, &rank)) {
    return nullptr;
  }
  THPUtils_assert(
      id_len == NCCL_UNIQUE_ID_BYTES,
      "invalid unique_id (expected %d bytes, got %zd)",
      NCCL_UNIQUE_ID_BYTES,
      id_len);

  ncclUniqueId commId;
  memcpy(&commId, id, NCCL_UNIQUE_ID_BYTES);
  ncclComm_t comm;
  {
    pybind11::gil_scoped_release no_gil;
    comm = comm_init_rank(nranks, commId, rank);
  }
  return PyCapsule_New(comm, COMM_CAPSULE_NAME, &destroy_nccl_comm);
  END_HANDLE_TH_ERRORS
}

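// Reduces `inputs` (one tensor per device) with reduction op `op`, writing
// the reduced result into `output`; `root` selects the reduction root.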
PyObject* THCPModule_nccl_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_output, *_streams, *_comms;
  int root, op;

  if (!PyArg_ParseTuple(
          args, "OOiiOO", &_inputs, &_output, &root, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce",
        1,
        "(sequence[Tensor] inputs, Tensor output, int root, int op,"
        " sequence[torch.cuda.Stream or None] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  auto output = extract_tensor(_output);
  std::vector<c10::optional<at::cuda::CUDAStream>> streams =
      unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    torch::cuda::nccl::reduce(inputs, output, root, op, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

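// Elementwise-reduces each inputs[i] across devices with `op` and writes the
// full result into every outputs[i] (all-reduce semantics).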
PyObject* THCPModule_nccl_all_reduce(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_reduce",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    all_reduce(inputs, outputs, op, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

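// Broadcasts across devices. Note that `root` is only validated here and not
// forwarded: torch::cuda::nccl::broadcast takes only the tensor list,
// streams, and comms.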
PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_streams, *_comms;
  int root;

  if (!PyArg_ParseTuple(args, "OiOO", &_inputs, &root, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_broadcast",
        1,
        "(sequence[Tensor] inputs, int root,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  THPUtils_assert(root >= 0 && (size_t)root < inputs.size(), "invalid root");
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    torch::cuda::nccl::broadcast(inputs, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

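// Gathers every inputs[i] into each outputs[i], so every device ends up with
// the concatenation of all inputs (all-gather semantics).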
PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;

  if (!PyArg_ParseTuple(
          args, "OOOO", &_inputs, &_outputs, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_all_gather",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    all_gather(inputs, outputs, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

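// Reduces the inputs with `op` and scatters the result, leaving each
// outputs[i] with one equal slice of the reduction (reduce-scatter
// semantics).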
PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject *_inputs, *_outputs, *_streams, *_comms;
  int op;

  if (!PyArg_ParseTuple(
          args, "OOiOO", &_inputs, &_outputs, &op, &_streams, &_comms)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "nccl_reduce_scatter",
        1,
        "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op,"
        " sequence[torch.cuda.Stream] streams,"
        " sequence[torch.cuda.nccl.Communicator] comms)");
    return nullptr;
  }

  std::vector<at::Tensor> inputs = extract_tensors(_inputs);
  std::vector<at::Tensor> outputs = extract_tensors(_outputs);
  auto streams = unpack_streams(_streams, inputs.size());
  auto user_comms = unpack_comms(_comms, inputs.size());

  {
    pybind11::gil_scoped_release no_gil;
    reduce_scatter(inputs, outputs, op, streams, user_comms);
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

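// Unwraps a single THPVariable into an at::Tensor, raising TypeError for
// anything that is not a Tensor.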
static inline at::Tensor extract_tensor(PyObject* obj) {
  if (!THPVariable_Check(obj)) {
    throw torch::TypeError("expected Tensor (got %s)", Py_TYPE(obj)->tp_name);
  }
  return THPVariable_Unpack(obj);
}

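// Unwraps a Python sequence of THPVariables into a std::vector<at::Tensor>,
// reporting the index of the first non-Tensor element on failure.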
static inline std::vector<at::Tensor> extract_tensors(PyObject* obj) {
  auto seq = THPObjectPtr(PySequence_Fast(obj, "expected a sequence"));
  if (!seq)
    throw python_error();

  const Py_ssize_t length = PySequence_Fast_GET_SIZE(seq.get());
  std::vector<at::Tensor> list;
  if (length >= 0) {
    list.reserve(length);
  }
  for (Py_ssize_t i = 0; i < length; i++) {
    PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
    if (!THPVariable_Check(item)) {
      throw torch::TypeError(
          "expected Tensor at %d (got %s)", (int)i, Py_TYPE(item)->tp_name);
    }
    list.emplace_back(THPVariable_Unpack(item));
  }
  return list;
}