#pragma once

#include <ATen/ATen.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/Types.hpp>

#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
typedef SSIZE_T ssize_t;
#pragma comment(lib, "Ws2_32.lib")
#else
#include <fcntl.h>
#include <netdb.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <unistd.h>
#endif

#include <sys/types.h>

#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <sstream>
#include <string>
#include <system_error>
#include <tuple>
#include <vector>

namespace c10d {

TORCH_API std::string parse_env(const char* env_var_name);

// Retrieve the shapes of a given list of tensors, each shape returned as a
// tensor.
TORCH_API std::vector<at::Tensor> getTensorShapes(
    const std::vector<at::Tensor>& tensors);

// Use -2 to represent the unset state of an environment variable.
#define C10D_ENV_NOT_SET -2

// Turns at::IntArrayRef into "(1, 2, 3, 4)".
inline std::string toString(at::IntArrayRef l) {
  std::stringstream ss;
  ss << "(";
  for (const auto i : c10::irange(l.size())) {
    if (i > 0) {
      ss << ", ";
    }
    ss << l[i];
  }
  ss << ")";
  return ss.str();
}
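
// Example (illustrative): toString(at::IntArrayRef{1, 2, 3}) returns
// "(1, 2, 3)", and an empty list returns "()".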

inline std::string toString(const c10::Layout& layout) {
  std::stringstream ss;
  ss << layout;
  return ss.str();
}

inline void assertSameType(
    const at::DeprecatedTypeProperties& type,
    const std::vector<at::Tensor>& tensors) {
  for (const auto i : c10::irange(tensors.size())) {
    if (!tensors[i].options().type_equal(type.options())) {
      const std::string expected = type.toString();
      const std::string actual = tensors[i].toString();
      throw std::invalid_argument(
          "mixed types (" + expected + " and " + actual + ")");
    }
  }
}

inline std::vector<std::string> split(
    char separator,
    const std::string& string) {
  std::vector<std::string> pieces;
  std::stringstream ss(string);
  std::string item;
  while (std::getline(ss, item, separator)) {
    pieces.push_back(std::move(item));
  }
  return pieces;
}
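
// Example (illustrative): split(',', "127.0.0.1,10.0.0.1") returns
// {"127.0.0.1", "10.0.0.1"}. Note that, following std::getline semantics,
// a trailing separator does not produce a trailing empty piece.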

inline int parseEnvVarInt(const char* envVarName) {
  char* stringValue = std::getenv(envVarName);
  if (stringValue != nullptr) {
    int val;
    try {
      val = std::stoi(stringValue);
    } catch (const std::exception&) {
      TORCH_CHECK(
          false,
          "Invalid value for environment variable: " + std::string(envVarName));
    }
    return val;
  }
  return C10D_ENV_NOT_SET;
}

inline int parseEnvVarIntDefault(const char* envVarName, int defaultVal) {
  int val = parseEnvVarInt(envVarName);
  if (val == C10D_ENV_NOT_SET)
    return defaultVal;
  return val;
}

inline bool parseEnvVarFlag(const char* envVarName) {
  int val = parseEnvVarInt(envVarName);
  if (val == 1) {
    return true;
  } else if (val == 0 || val == C10D_ENV_NOT_SET) {
    return false;
  }
  TORCH_CHECK(
      false,
      "Invalid value for environment variable: " + std::string(envVarName));
  return false;
}
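
// Example (illustrative, using a hypothetical variable MY_C10D_FLAG):
// if MY_C10D_FLAG is unset, parseEnvVarInt returns C10D_ENV_NOT_SET,
// parseEnvVarIntDefault(..., 3) returns 3, and parseEnvVarFlag returns
// false. If it is set to "1" or "0", parseEnvVarFlag returns true or
// false respectively; any other value makes parseEnvVarFlag (and
// parseEnvVarInt on non-numeric input) raise an error.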

inline void assertSameSizes(
    const at::IntArrayRef& sizes,
    const std::vector<at::Tensor>& tensors) {
  for (const auto i : c10::irange(tensors.size())) {
    if (!tensors[i].sizes().equals(sizes)) {
      const auto expected = toString(sizes);
      const auto actual = toString(tensors[i].sizes());
      throw std::invalid_argument(
          "mixed sizes (" + expected + " and " + actual + ")");
    }
  }
}

inline void assertSameSizeAndType(const std::vector<at::Tensor>& tensors) {
  // Ensure we have at least one tensor
  if (tensors.empty()) {
    throw std::invalid_argument("argument is empty");
  }

  // Ensure all tensors have identical type and shape
  auto options = tensors[0].options();
  auto sizes = tensors[0].sizes();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (!tensors[i].options().type_equal(options)) {
      const auto expected = toString(options);
      const auto actual = toString(tensors[i].options());
      throw std::invalid_argument(
          "argument contains mixed types (" + expected + " and " + actual +
          ")");
    }
    if (!tensors[i].sizes().equals(sizes)) {
      const auto expected = toString(sizes);
      const auto actual = toString(tensors[i].sizes());
      throw std::invalid_argument(
          "argument contains mixed sizes (" + expected + " and " + actual +
          ")");
    }
  }
}

inline void assertTypeMatch(
    std::function<void(const std::string&)> fn,
    const at::DeprecatedTypeProperties& type,
    const at::ArrayRef<at::Tensor> tensors,
    size_t index) {
  if (!tensors[index].options().type_equal(type.options())) {
    fn("invalid tensor type at index " + std::to_string(index) +
       " (expected " + type.toString() + ", got " +
       tensors[index].toString() + ")");
  }
}

inline void assertTypeMatch(
    std::function<void(const std::string&)> fn,
    const at::TensorOptions& options,
    const at::ArrayRef<at::Tensor> tensors,
    size_t index) {
  if (!tensors[index].options().type_equal(options)) {
    fn("invalid tensor type at index " + std::to_string(index) +
       " (expected " + toString(options) + ", got " +
       toString(tensors[index].options()) + ")");
  }
}

inline void assertSizesMatch(
    std::function<void(const std::string&)> fn,
    const at::IntArrayRef& sizes,
    const at::ArrayRef<at::Tensor> tensors,
    size_t index) {
  if (tensors[index].sizes() != sizes) {
    fn("invalid tensor size at index " + std::to_string(index) +
       " (expected " + toString(sizes) + ", got " +
       toString(tensors[index].sizes()) + ")");
  }
}

inline void assertLayoutMatch(
    std::function<void(const std::string&)> fn,
    const c10::Layout& expected,
    const at::ArrayRef<at::Tensor> tensors,
    size_t index) {
  const auto& actual = tensors[index].layout();
  if (actual != expected) {
    fn("invalid tensor layout at index " + std::to_string(index) +
       " (expected " + toString(expected) + ", got " + toString(actual) +
       ")");
  }
}

inline void assertLayoutMatch(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  const auto& layout = tensors[0].layout();
  for (const auto i : c10::irange(1, tensors.size())) {
    assertLayoutMatch(fn, layout, tensors, i);
  }
}

inline void assertNonEmpty(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.empty()) {
    fn("requires non-empty tensor list");
  }
}

inline void assertSingleElement(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.size() != 1) {
    fn("requires a single-element tensor list");
  }
}

inline void assertSingleElementInput(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.size() != 1) {
    fn("requires a single-element input tensor list");
  }
}

inline void assertSingleElementOutput(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.size() != 1) {
    fn("requires a single-element output tensor list");
  }
}

inline void assertRootRank(
    std::function<void(const std::string&)> fn,
    int rank,
    int size) {
  if (rank < 0 || rank >= size) {
    fn("invalid root rank: " + std::to_string(rank));
  }
}

inline void assertRootTensor(
    std::function<void(const std::string&)> fn,
    int rank,
    int size) {
  if (rank < 0 || rank >= size) {
    fn("invalid root tensor: " + std::to_string(rank));
  }
}

inline void assertDense(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  const auto& layout = tensors[0].layout();
  if (layout != at::kStrided) {
    fn("only supports dense tensors");
  }
}

inline void assertCPU(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  const auto& device = tensors[0].device();
  if (device.type() != at::kCPU) {
    fn("only supports CPU tensors");
  }
}

inline void assertSameDevice(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  if (tensors.size() < 2) {
    return;
  }
  const auto& device = tensors[0].device();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (tensors[i].device() != device) {
      fn("tensors should be on the same device");
    }
  }
}

inline void assertTypeAndSizesMatch(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors,
    const at::DeprecatedTypeProperties& type,
    const at::IntArrayRef& sizes) {
  for (const auto i : c10::irange(tensors.size())) {
    assertTypeMatch(fn, type, tensors, i);
    assertSizesMatch(fn, sizes, tensors, i);
  }
}

inline void assertTypeAndSizesMatch(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors,
    const at::TensorOptions& options,
    const at::IntArrayRef& sizes) {
  for (const auto i : c10::irange(tensors.size())) {
    assertTypeMatch(fn, options, tensors, i);
    assertSizesMatch(fn, sizes, tensors, i);
  }
}

inline void assertTypeAndSizesMatch(
    std::function<void(const std::string&)> fn,
    const at::ArrayRef<at::Tensor> tensors) {
  const auto& options = tensors[0].options();
  const auto sizes = tensors[0].sizes();
  assertTypeAndSizesMatch(fn, tensors.slice(1), options, sizes);
}
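
// Example (illustrative): the assert* helpers above report failures through
// the `fn` callback rather than throwing directly, so the caller picks the
// error type and message prefix:
//
//   const auto invalidArg = [](const std::string& msg) {
//     throw std::invalid_argument("broadcast: " + msg);
//   };
//   assertNonEmpty(invalidArg, tensors);
//   assertLayoutMatch(invalidArg, tensors);
//   assertTypeAndSizesMatch(invalidArg, tensors);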

// Copied from ATen/core/functional.h.
template <typename F, typename T>
inline auto fmap(T& inputs, const F& fn)
    -> std::vector<decltype(fn(*inputs.begin()))> {
  std::vector<decltype(fn(*inputs.begin()))> r;
  r.reserve(inputs.size());
  for (auto& input : inputs) {
    r.push_back(fn(input));
  }
  return r;
}

// Copied from torch/csrc/utils/tensor_flatten.h.
inline at::Tensor flattenDenseTensors(at::TensorList tensors) {
  static const auto flatten = [](const at::Tensor& t) {
    return t.contiguous().view({-1});
  };
  if (tensors.size() == 1) {
    return flatten(tensors[0]);
  }
  return at::cat(::c10d::fmap(tensors, flatten));
}
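
// Example (illustrative): flattening a [2, 3] tensor and a [4] tensor yields
// a single contiguous 1-D tensor with 2 * 3 + 4 = 10 elements.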

inline at::Tensor newLikeFlat(
    std::vector<std::vector<at::Tensor>>& tensors,
    size_t deviceIdx) {
  if (tensors.empty() || tensors[0].empty()) {
    TORCH_CHECK(false, "Received an empty list");
  }
  if (deviceIdx >= tensors.size()) {
    TORCH_CHECK(false, "Invalid device index");
  }
  auto& t = tensors[deviceIdx][0];
  auto device = t.device();
  for (const auto i : c10::irange(1, tensors[deviceIdx].size())) {
    if (tensors[deviceIdx][i].device() != device) {
      TORCH_CHECK(false, "Expecting all tensors on the same device");
    }
  }
  at::DeviceGuard gpuGuard(device);
  std::vector<int64_t> sizes{static_cast<int64_t>(tensors[deviceIdx].size())};
  std::vector<int64_t> strides{static_cast<int64_t>(t.numel())};
  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
  strides.insert(strides.end(), t.strides().begin(), t.strides().end());
  return at::empty_strided(
      sizes, strides, t.options().memory_format(c10::nullopt));
}

inline at::Tensor newLikeFlat(std::vector<at::Tensor>& tensors) {
  if (tensors.empty()) {
    TORCH_CHECK(false, "Received an empty list");
  }
  auto& t = tensors[0];
  at::DeviceGuard gpuGuard(t.device());
  std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
  sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
  return at::empty(sizes, t.options());
}
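
// Example (illustrative): for a vector of n tensors, each of shape [4, 5],
// either newLikeFlat overload allocates an uninitialized output of shape
// [n, 4, 5] on the tensors' device, i.e. one flat buffer with a new leading
// dimension indexing the individual tensors.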

inline std::vector<std::vector<int64_t>> getSizes(
    const std::vector<at::Tensor>& tensors) {
  std::vector<std::vector<int64_t>> sizes(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    sizes[i] = tensors[i].sizes().vec();
  }
  return sizes;
}

inline std::vector<int> getDevices(const std::vector<at::Tensor>& tensors) {
  std::vector<int> devices(tensors.size(), -1);
  if (tensors[0].device().is_cuda()) {
    for (const auto i : c10::irange(tensors.size())) {
      devices[i] = tensors[i].storage().device().index();
    }
  }
  return devices;
}

template <typename T>
inline T* getDataPointer(const at::Tensor& tensor) {
  // This method is only used in ProcessGroupGloo for now. Call sites must make
  // sure that the input tensor is contiguous. It is OK if the tensor does not
  // start from the beginning of the storage. For example, it could come from
  // chunk(..., dim=0)[1]. Hence, we need to use data_ptr() instead of
  // tensor.storage().data().
  // NB: not using tensor.data<T>() because tensor is not aware of gloo::TYPE.
  return static_cast<T*>(tensor.data_ptr());
}

template <typename T>
std::vector<T*> getDataPointers(const std::vector<at::Tensor>& tensors) {
  std::vector<T*> ptrs(tensors.size());
  for (const auto i : c10::irange(tensors.size())) {
    ptrs[i] = getDataPointer<T>(tensors[i]);
  }
  return ptrs;
}

// For alltoall split size sanity checks.
inline void checkSplitSizes(
    const std::vector<int64_t>& split_sizes,
    const at::Tensor& tensor,
    int group_size) {
  if (split_sizes.empty()) {
    TORCH_CHECK(
        tensor.size(0) % group_size == 0,
        "Tensor's dim 0 does not divide equally across group size");
  } else {
    TORCH_CHECK(
        split_sizes.size() == static_cast<size_t>(group_size),
        "Number of tensor splits not equal to group size");
    const auto sum = c10::sum_integers(split_sizes);
    TORCH_CHECK(
        sum == tensor.size(0), "Split sizes don't match total dim 0 size");
  }
}
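
// Example (illustrative): for a group of size 4 and a tensor with
// tensor.size(0) == 10, an empty split_sizes fails (10 % 4 != 0), while
// split_sizes == {4, 3, 2, 1} passes (four entries summing to 10).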

// Compute alltoall lengths and offsets, handling multi-dimension tensors.
template <typename T>
size_t computeLengthsAndOffsets(
    const std::vector<int64_t>& split_sizes,
    const at::Tensor& tensor,
    std::vector<T>* lengths,
    std::vector<T>* offsets) {
  size_t group_size = lengths->size();
  bool equal_splits = false;
  size_t dim0_size = tensor.size(0);
  size_t row_size = (dim0_size ? tensor.numel() / dim0_size : 1);
  size_t split_size = 0;
  size_t offset = 0;

  if (split_sizes.empty()) {
    equal_splits = true;
    split_size = tensor.size(0) / group_size;
  }
  for (const auto i : c10::irange(group_size)) {
    size_t length = row_size * (equal_splits ? split_size : split_sizes[i]);
    TORCH_INTERNAL_ASSERT(
        length <= std::numeric_limits<int>::max() &&
            offset <= std::numeric_limits<int>::max(),
        "Length or offset larger than INT_MAX not supported");
    (*lengths)[i] = length;
    (*offsets)[i] = offset;
    offset += length;
  }
  return offset;
}
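
// Example (illustrative): for a [6, 2] tensor (row_size = 2) and
// split_sizes = {3, 2, 1}, the per-rank element counts come out as
// lengths = {6, 4, 2} with cumulative offsets = {0, 6, 10}; the returned
// total is 12 == tensor.numel().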

template <typename T>
size_t computeLengthsAndOffsets(
    const std::vector<at::Tensor>& tensors,
    std::vector<T>* lengths,
    std::vector<T>* offsets) {
  size_t group_size = lengths->size();
  size_t offset = 0;
  for (const auto i : c10::irange(group_size)) {
    size_t length = tensors[i].numel();
    TORCH_INTERNAL_ASSERT(
        length <= std::numeric_limits<int>::max() &&
            offset <= std::numeric_limits<int>::max(),
        "Length or offset larger than INT_MAX not supported");
    (*lengths)[i] = length;
    (*offsets)[i] = offset;
    offset += length;
  }
  return offset;
}

using RankType = uint32_t;
using SizeType = uint64_t;

// `errno` is only meaningful when a call fails. E.g., a successful `fork()`
// sets `errno` to `EINVAL` in the child process on some macOS versions
// (https://stackoverflow.com/a/20295079), so `errno` should only be
// inspected if an error actually occurred.
//
// `success_cond` is an expression used to check if an error has happened. So
// for `fork()`, we can use `SYSCHECK(pid = fork(), pid != -1)`. The function
// output is stored in the variable `__output` and may be used in
// `success_cond`.
#ifdef _WIN32
#define SYSCHECK(expr, success_cond)                                      \
  while (true) {                                                          \
    auto __output = (expr);                                               \
    auto errno_local = WSAGetLastError();                                 \
    (void)__output;                                                       \
    if (!(success_cond)) {                                                \
      /* Winsock reports errors via WSAGetLastError(), not `errno`. */    \
      if (errno_local == WSAEINTR) {                                      \
        continue;                                                         \
      } else if (                                                         \
          errno_local == WSAETIMEDOUT || errno_local == WSAEWOULDBLOCK) { \
        TORCH_CHECK(false, "Socket Timeout");                             \
      } else {                                                            \
        throw std::system_error(errno_local, std::system_category());     \
      }                                                                   \
    } else {                                                              \
      break;                                                              \
    }                                                                     \
  }
#else
#define SYSCHECK(expr, success_cond)                             \
  while (true) {                                                 \
    auto __output = (expr);                                      \
    (void)__output;                                              \
    if (!(success_cond)) {                                       \
      if (errno == EINTR) {                                      \
        continue;                                                \
      } else if (errno == EAGAIN || errno == EWOULDBLOCK) {      \
        TORCH_CHECK(false, "Socket Timeout");                    \
      } else {                                                   \
        throw std::system_error(errno, std::system_category());  \
      }                                                          \
    } else {                                                     \
      break;                                                     \
    }                                                            \
  }
#endif

// Most functions indicate error by returning `-1`. This is a helper macro for
// that common case with `SYSCHECK`. Since `SOCKET_ERROR` is defined as `-1`
// in MSVC, Windows socket calls can use `SYSCHECK_ERR_RETURN_NEG1` as well.
#define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1)
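
// Example (illustrative): a retry-on-EINTR wrapper around a raw send, with
// the result captured in `bytesSent` for inspection after the macro:
//
//   ssize_t bytesSent;
//   SYSCHECK_ERR_RETURN_NEG1(bytesSent = ::send(socket, buf, len, 0))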

namespace tcputil {

// Send and receive
template <typename T>
void sendBytes(
    int socket,
    const T* buffer,
    size_t length,
    bool moreData = false) {
  size_t bytesToSend = sizeof(T) * length;
  if (bytesToSend == 0) {
    return;
  }

  // The buffer is only read from, so no const_cast is needed.
  const uint8_t* currentBytes = reinterpret_cast<const uint8_t*>(buffer);

  int flags = 0;

#ifdef MSG_MORE
  if (moreData) { // there is more data to send
    flags |= MSG_MORE;
  }
#endif

// Ignore SIGPIPE as the send() return value is always checked for error
#ifdef MSG_NOSIGNAL
  flags |= MSG_NOSIGNAL;
#endif

  while (bytesToSend > 0) {
    ssize_t bytesSent;
    SYSCHECK_ERR_RETURN_NEG1(
        bytesSent = ::send(
            socket,
            reinterpret_cast<const char*>(currentBytes),
            bytesToSend,
            flags))
    if (bytesSent == 0) {
      throw std::system_error(ECONNRESET, std::system_category());
    }

    bytesToSend -= bytesSent;
    currentBytes += bytesSent;
  }
}

template <typename T>
void recvBytes(int socket, T* buffer, size_t length) {
  size_t bytesToReceive = sizeof(T) * length;
  if (bytesToReceive == 0) {
    return;
  }

  auto currentBytes = reinterpret_cast<uint8_t*>(buffer);

  while (bytesToReceive > 0) {
    ssize_t bytesReceived;
    SYSCHECK_ERR_RETURN_NEG1(
        bytesReceived = recv(
            socket, reinterpret_cast<char*>(currentBytes), bytesToReceive, 0))
    if (bytesReceived == 0) {
      throw std::system_error(ECONNRESET, std::system_category());
    }

    bytesToReceive -= bytesReceived;
    currentBytes += bytesReceived;
  }
}

// Send a vector's length and data.
template <typename T>
void sendVector(int socket, const std::vector<T>& vec, bool moreData = false) {
  SizeType size = vec.size();
  sendBytes<SizeType>(socket, &size, 1, true);
  sendBytes<T>(socket, vec.data(), size, moreData);
}

// Receive a vector as sent in sendVector.
template <typename T>
std::vector<T> recvVector(int socket) {
  SizeType valueSize;
  recvBytes<SizeType>(socket, &valueSize, 1);
  std::vector<T> value(valueSize);
  recvBytes<T>(socket, value.data(), value.size());
  return value;
}
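
// Example (illustrative): sendVector<int32_t>(fd, {1, 2, 3}) writes an
// 8-byte SizeType length (3) followed by 12 bytes of payload, and
// recvVector<int32_t>(fd) on the peer reads them back in the same order.
// Both ends must agree on T and native byte order; no endianness conversion
// is performed.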

// This is only for convenience when sending rvalues. Note that the value is
// copied byte-for-byte, so T should be trivially copyable.
template <typename T>
void sendValue(int socket, const T& value, bool moreData = false) {
  sendBytes<T>(socket, &value, 1, moreData);
}

template <typename T>
T recvValue(int socket) {
  T value;
  recvBytes<T>(socket, &value, 1);
  return value;
}

// Send a string's length and data.
inline void sendString(
    int socket,
    const std::string& str,
    bool moreData = false) {
  SizeType size = str.size();
  sendBytes<SizeType>(socket, &size, 1, true);
  sendBytes<char>(socket, str.data(), size, moreData);
}

// Receive a string as sent in sendString.
inline std::string recvString(int socket) {
  SizeType valueSize;
  recvBytes<SizeType>(socket, &valueSize, 1);
  std::vector<char> value(valueSize);
  recvBytes<char>(socket, value.data(), value.size());
  return std::string(value.data(), value.size());
}

} // namespace tcputil
} // namespace c10d