1 | // Copyright (c) Meta Platforms, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // This source code is licensed under the BSD-style license found in the |
5 | // LICENSE file in the root directory of this source tree. |
6 | |
#include <torch/csrc/distributed/c10d/socket.h>

#include <cstring>
#include <limits>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>

#ifdef _WIN32
#include <mutex>

#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <fcntl.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#endif

#include <fmt/chrono.h>
#include <fmt/format.h>

#include <torch/csrc/distributed/c10d/error.h>
#include <torch/csrc/distributed/c10d/exception.h>
#include <torch/csrc/distributed/c10d/logging.h>

#include <c10/util/CallOnce.h>
38 | |
39 | namespace c10d { |
40 | namespace detail { |
41 | namespace { |
42 | #ifdef _WIN32 |
43 | |
44 | // Since Winsock uses the name `WSAPoll` instead of `poll`, we alias it here |
45 | // to avoid #ifdefs in the source code. |
46 | const auto pollFd = ::WSAPoll; |
47 | |
48 | // Winsock's `getsockopt()` and `setsockopt()` functions expect option values to |
49 | // be passed as `char*` instead of `void*`. We wrap them here to avoid redundant |
50 | // casts in the source code. |
51 | int getSocketOption( |
52 | SOCKET s, |
53 | int level, |
54 | int optname, |
55 | void* optval, |
56 | int* optlen) { |
57 | return ::getsockopt(s, level, optname, static_cast<char*>(optval), optlen); |
58 | } |
59 | |
60 | int setSocketOption( |
61 | SOCKET s, |
62 | int level, |
63 | int optname, |
64 | const void* optval, |
65 | int optlen) { |
66 | return ::setsockopt( |
67 | s, level, optname, static_cast<const char*>(optval), optlen); |
68 | } |
69 | |
70 | // Winsock has its own error codes which differ from Berkeley's. Fortunately the |
71 | // C++ Standard Library on Windows can map them to standard error codes. |
72 | inline std::error_code getSocketError() noexcept { |
73 | return std::error_code{::WSAGetLastError(), std::system_category()}; |
74 | } |
75 | |
76 | inline void setSocketError(int val) noexcept { |
77 | ::WSASetLastError(val); |
78 | } |
79 | |
80 | #else |
81 | |
82 | const auto pollFd = ::poll; |
83 | |
84 | const auto getSocketOption = ::getsockopt; |
85 | const auto setSocketOption = ::setsockopt; |
86 | |
87 | inline std::error_code getSocketError() noexcept { |
88 | return lastError(); |
89 | } |
90 | |
91 | inline void setSocketError(int val) noexcept { |
92 | errno = val; |
93 | } |
94 | |
95 | #endif |
96 | |
97 | // Suspends the current thread for the specified duration. |
98 | void delay(std::chrono::seconds d) { |
99 | #ifdef _WIN32 |
100 | std::this_thread::sleep_for(d); |
101 | #else |
102 | ::timespec req{}; |
103 | req.tv_sec = d.count(); |
104 | |
105 | // The C++ Standard does not specify whether `sleep_for()` should be signal- |
106 | // aware; therefore, we use the `nanosleep()` syscall. |
107 | if (::nanosleep(&req, nullptr) != 0) { |
108 | std::error_code err = getSocketError(); |
109 | // We don't care about error conditions other than EINTR since a failure |
110 | // here is not critical. |
111 | if (err == std::errc::interrupted) { |
112 | throw std::system_error{err}; |
113 | } |
114 | } |
115 | #endif |
116 | } |
117 | |
118 | class SocketListenOp; |
119 | class SocketConnectOp; |
120 | } // namespace |
121 | |
122 | class SocketImpl { |
123 | friend class SocketListenOp; |
124 | friend class SocketConnectOp; |
125 | |
126 | public: |
127 | #ifdef _WIN32 |
128 | using Handle = SOCKET; |
129 | #else |
130 | using Handle = int; |
131 | #endif |
132 | |
133 | #ifdef _WIN32 |
134 | static constexpr Handle invalid_socket = INVALID_SOCKET; |
135 | #else |
136 | static constexpr Handle invalid_socket = -1; |
137 | #endif |
138 | |
139 | explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {} |
140 | |
141 | SocketImpl(const SocketImpl& other) = delete; |
142 | |
143 | SocketImpl& operator=(const SocketImpl& other) = delete; |
144 | |
145 | SocketImpl(SocketImpl&& other) noexcept = delete; |
146 | |
147 | SocketImpl& operator=(SocketImpl&& other) noexcept = delete; |
148 | |
149 | ~SocketImpl(); |
150 | |
151 | std::unique_ptr<SocketImpl> accept() const; |
152 | |
153 | void closeOnExec() noexcept; |
154 | |
155 | void enableNonBlocking(); |
156 | |
157 | void disableNonBlocking(); |
158 | |
159 | bool enableNoDelay() noexcept; |
160 | |
161 | bool enableDualStack() noexcept; |
162 | |
163 | #ifndef _WIN32 |
164 | bool enableAddressReuse() noexcept; |
165 | #endif |
166 | |
167 | #ifdef _WIN32 |
168 | bool enableExclusiveAddressUse() noexcept; |
169 | #endif |
170 | |
171 | std::uint16_t getPort() const; |
172 | |
173 | Handle handle() const noexcept { |
174 | return hnd_; |
175 | } |
176 | |
177 | private: |
178 | bool setSocketFlag(int level, int optname, bool value) noexcept; |
179 | |
180 | Handle hnd_; |
181 | }; |
182 | } // namespace detail |
183 | } // namespace c10d |
184 | |
//
// libfmt formatters for `addrinfo` and `SocketImpl`
//
188 | namespace fmt { |
189 | |
190 | template <> |
191 | struct formatter<::addrinfo> { |
192 | constexpr decltype(auto) parse(format_parse_context& ctx) { |
193 | return ctx.begin(); |
194 | } |
195 | |
196 | template <typename FormatContext> |
197 | decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) { |
198 | char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT |
199 | |
200 | int r = ::getnameinfo( |
201 | addr.ai_addr, |
202 | addr.ai_addrlen, |
203 | host, |
204 | NI_MAXHOST, |
205 | port, |
206 | NI_MAXSERV, |
207 | NI_NUMERICSERV); |
208 | if (r != 0) { |
209 | return format_to(ctx.out(), "?UNKNOWN?" ); |
210 | } |
211 | |
212 | if (addr.ai_addr->sa_family == AF_INET) { |
213 | return format_to(ctx.out(), "{}:{}" , host, port); |
214 | } else { |
215 | return format_to(ctx.out(), "[{}]:{}" , host, port); |
216 | } |
217 | } |
218 | }; |
219 | |
220 | template <> |
221 | struct formatter<c10d::detail::SocketImpl> { |
222 | constexpr decltype(auto) parse(format_parse_context& ctx) { |
223 | return ctx.begin(); |
224 | } |
225 | |
226 | template <typename FormatContext> |
227 | decltype(auto) format( |
228 | const c10d::detail::SocketImpl& socket, |
229 | FormatContext& ctx) { |
230 | ::sockaddr_storage addr_s{}; |
231 | |
232 | auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s); |
233 | |
234 | ::socklen_t addr_len = sizeof(addr_s); |
235 | |
236 | if (::getsockname(socket.handle(), addr_ptr, &addr_len) != 0) { |
237 | return format_to(ctx.out(), "?UNKNOWN?" ); |
238 | } |
239 | |
240 | ::addrinfo addr{}; |
241 | addr.ai_addr = addr_ptr; |
242 | addr.ai_addrlen = addr_len; |
243 | |
244 | return format_to(ctx.out(), "{}" , addr); |
245 | } |
246 | }; |
247 | |
248 | } // namespace fmt |
249 | |
250 | namespace c10d { |
251 | namespace detail { |
252 | |
253 | SocketImpl::~SocketImpl() { |
254 | #ifdef _WIN32 |
255 | ::closesocket(hnd_); |
256 | #else |
257 | ::close(hnd_); |
258 | #endif |
259 | } |
260 | |
261 | std::unique_ptr<SocketImpl> SocketImpl::accept() const { |
262 | ::sockaddr_storage addr_s{}; |
263 | |
264 | auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s); |
265 | |
266 | ::socklen_t addr_len = sizeof(addr_s); |
267 | |
268 | Handle hnd = ::accept(hnd_, addr_ptr, &addr_len); |
269 | if (hnd == invalid_socket) { |
270 | std::error_code err = getSocketError(); |
271 | if (err == std::errc::interrupted) { |
272 | throw std::system_error{err}; |
273 | } |
274 | |
275 | std::string msg{}; |
276 | if (err == std::errc::invalid_argument) { |
277 | msg = fmt::format( |
278 | "The server socket on {} is not listening for connections." , *this); |
279 | } else { |
280 | msg = fmt::format( |
281 | "The server socket on {} has failed to accept a connection {}." , |
282 | *this, |
283 | err); |
284 | } |
285 | |
286 | C10D_ERROR(msg); |
287 | |
288 | throw SocketError{msg}; |
289 | } |
290 | |
291 | ::addrinfo addr{}; |
292 | addr.ai_addr = addr_ptr; |
293 | addr.ai_addrlen = addr_len; |
294 | |
295 | C10D_DEBUG( |
296 | "The server socket on {} has accepted a connection from {}." , |
297 | *this, |
298 | addr); |
299 | |
300 | auto impl = std::make_unique<SocketImpl>(hnd); |
301 | |
302 | // Make sure that we do not "leak" our file descriptors to child processes. |
303 | impl->closeOnExec(); |
304 | |
305 | if (!impl->enableNoDelay()) { |
306 | C10D_WARNING( |
307 | "The no-delay option cannot be enabled for the client socket on {}." , |
308 | addr); |
309 | } |
310 | |
311 | return impl; |
312 | } |
313 | |
314 | void SocketImpl::closeOnExec() noexcept { |
315 | #ifndef _WIN32 |
316 | ::fcntl(hnd_, F_SETFD, FD_CLOEXEC); |
317 | #endif |
318 | } |
319 | |
320 | void SocketImpl::enableNonBlocking() { |
321 | #ifdef _WIN32 |
322 | unsigned long value = 1; |
323 | if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) { |
324 | return; |
325 | } |
326 | #else |
327 | int flg = ::fcntl(hnd_, F_GETFL); |
328 | if (flg != -1) { |
329 | if (::fcntl(hnd_, F_SETFL, flg | O_NONBLOCK) == 0) { |
330 | return; |
331 | } |
332 | } |
333 | #endif |
334 | throw SocketError{"The socket cannot be switched to non-blocking mode." }; |
335 | } |
336 | |
337 | // TODO: Remove once we migrate everything to non-blocking mode. |
338 | void SocketImpl::disableNonBlocking() { |
339 | #ifdef _WIN32 |
340 | unsigned long value = 0; |
341 | if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) { |
342 | return; |
343 | } |
344 | #else |
345 | int flg = ::fcntl(hnd_, F_GETFL); |
346 | if (flg != -1) { |
347 | if (::fcntl(hnd_, F_SETFL, flg & ~O_NONBLOCK) == 0) { |
348 | return; |
349 | } |
350 | } |
351 | #endif |
352 | throw SocketError{"The socket cannot be switched to blocking mode." }; |
353 | } |
354 | |
355 | bool SocketImpl::enableNoDelay() noexcept { |
356 | return setSocketFlag(IPPROTO_TCP, TCP_NODELAY, true); |
357 | } |
358 | |
359 | bool SocketImpl::enableDualStack() noexcept { |
360 | return setSocketFlag(IPPROTO_IPV6, IPV6_V6ONLY, false); |
361 | } |
362 | |
363 | #ifndef _WIN32 |
364 | bool SocketImpl::enableAddressReuse() noexcept { |
365 | return setSocketFlag(SOL_SOCKET, SO_REUSEADDR, true); |
366 | } |
367 | #endif |
368 | |
369 | #ifdef _WIN32 |
370 | bool SocketImpl::enableExclusiveAddressUse() noexcept { |
371 | return setSocketFlag(SOL_SOCKET, SO_EXCLUSIVEADDRUSE, true); |
372 | } |
373 | #endif |
374 | |
375 | std::uint16_t SocketImpl::getPort() const { |
376 | ::sockaddr_storage addr_s{}; |
377 | |
378 | ::socklen_t addr_len = sizeof(addr_s); |
379 | |
380 | if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) != |
381 | 0) { |
382 | throw SocketError{"The port number of the socket cannot be retrieved." }; |
383 | } |
384 | |
385 | if (addr_s.ss_family == AF_INET) { |
386 | return ntohs(reinterpret_cast<::sockaddr_in*>(&addr_s)->sin_port); |
387 | } else { |
388 | return ntohs(reinterpret_cast<::sockaddr_in6*>(&addr_s)->sin6_port); |
389 | } |
390 | } |
391 | |
392 | bool SocketImpl::setSocketFlag(int level, int optname, bool value) noexcept { |
393 | #ifdef _WIN32 |
394 | auto buf = value ? TRUE : FALSE; |
395 | #else |
396 | auto buf = value ? 1 : 0; |
397 | #endif |
398 | return setSocketOption(hnd_, level, optname, &buf, sizeof(buf)) == 0; |
399 | } |
400 | |
401 | namespace { |
402 | |
// Deleter that returns a getaddrinfo() result list to the system.
struct addrinfo_delete {
  void operator()(::addrinfo* list) const noexcept {
    ::freeaddrinfo(list);
  }
};

// Owning handle for a getaddrinfo() result list.
using addrinfo_ptr = std::unique_ptr<::addrinfo, addrinfo_delete>;
410 | |
411 | class SocketListenOp { |
412 | public: |
413 | SocketListenOp(std::uint16_t port, const SocketOptions& opts); |
414 | |
415 | std::unique_ptr<SocketImpl> run(); |
416 | |
417 | private: |
418 | bool tryListen(int family); |
419 | |
420 | bool tryListen(const ::addrinfo& addr); |
421 | |
422 | template <typename... Args> |
423 | void recordError(fmt::string_view format, Args&&... args) { |
424 | auto msg = fmt::vformat(format, fmt::make_format_args(args...)); |
425 | |
426 | C10D_WARNING(msg); |
427 | |
428 | errors_.emplace_back(std::move(msg)); |
429 | } |
430 | |
431 | std::string port_; |
432 | const SocketOptions* opts_; |
433 | std::vector<std::string> errors_{}; |
434 | std::unique_ptr<SocketImpl> socket_{}; |
435 | }; |
436 | |
437 | SocketListenOp::SocketListenOp(std::uint16_t port, const SocketOptions& opts) |
438 | : port_{fmt::to_string(port)}, opts_{&opts} {} |
439 | |
440 | std::unique_ptr<SocketImpl> SocketListenOp::run() { |
441 | if (opts_->prefer_ipv6()) { |
442 | C10D_DEBUG("The server socket will attempt to listen on an IPv6 address." ); |
443 | if (tryListen(AF_INET6)) { |
444 | return std::move(socket_); |
445 | } |
446 | |
447 | C10D_DEBUG("The server socket will attempt to listen on an IPv4 address." ); |
448 | if (tryListen(AF_INET)) { |
449 | return std::move(socket_); |
450 | } |
451 | } else { |
452 | C10D_DEBUG( |
453 | "The server socket will attempt to listen on an IPv4 or IPv6 address." ); |
454 | if (tryListen(AF_UNSPEC)) { |
455 | return std::move(socket_); |
456 | } |
457 | } |
458 | |
459 | constexpr auto* msg = |
460 | "The server socket has failed to listen on any local network address." ; |
461 | |
462 | C10D_ERROR(msg); |
463 | |
464 | throw SocketError{fmt::format("{} {}" , msg, fmt::join(errors_, " " ))}; |
465 | } |
466 | |
// Resolves the passive (wildcard) addresses of `family` for `port_` and tries
// to listen on each in order. Returns true as soon as one succeeds. Returns
// false (with the errors recorded) otherwise.
bool SocketListenOp::tryListen(int family) {
  ::addrinfo hints{};
  hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  ::addrinfo* raw_list = nullptr;
  const int gai_rc = ::getaddrinfo(nullptr, port_.c_str(), &hints, &raw_list);
  if (gai_rc != 0) {
    const char* gai_err = ::gai_strerror(gai_rc);

    const char* family_label = "";
    if (family == AF_INET) {
      family_label = "IPv4 ";
    } else if (family == AF_INET6) {
      family_label = "IPv6 ";
    }

    recordError(
        "The local {}network addresses cannot be retrieved (gai error: {} - {}).",
        family_label,
        gai_rc,
        gai_err);

    return false;
  }

  addrinfo_ptr owner{raw_list};

  for (const ::addrinfo* it = owner.get(); it != nullptr; it = it->ai_next) {
    C10D_DEBUG("The server socket is attempting to listen on {}.", *it);
    if (tryListen(*it)) {
      return true;
    }
  }

  return false;
}
500 | |
501 | bool SocketListenOp::tryListen(const ::addrinfo& addr) { |
502 | SocketImpl::Handle hnd = |
503 | ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); |
504 | if (hnd == SocketImpl::invalid_socket) { |
505 | recordError( |
506 | "The server socket cannot be initialized on {} {}." , |
507 | addr, |
508 | getSocketError()); |
509 | |
510 | return false; |
511 | } |
512 | |
513 | socket_ = std::make_unique<SocketImpl>(hnd); |
514 | |
515 | #ifndef _WIN32 |
516 | if (!socket_->enableAddressReuse()) { |
517 | C10D_WARNING( |
518 | "The address reuse option cannot be enabled for the server socket on {}." , |
519 | addr); |
520 | } |
521 | #endif |
522 | |
523 | #ifdef _WIN32 |
524 | // The SO_REUSEADDR flag has a significantly different behavior on Windows |
525 | // compared to Unix-like systems. It allows two or more processes to share |
526 | // the same port simultaneously, which is totally unsafe. |
527 | // |
528 | // Here we follow the recommendation of Microsoft and use the non-standard |
529 | // SO_EXCLUSIVEADDRUSE flag instead. |
530 | if (!socket_->enableExclusiveAddressUse()) { |
531 | C10D_WARNING( |
532 | "The exclusive address use option cannot be enabled for the server socket on {}." , |
533 | addr); |
534 | } |
535 | #endif |
536 | |
537 | // Not all operating systems support dual-stack sockets by default. Since we |
538 | // wish to use our IPv6 socket for IPv4 communication as well, we explicitly |
539 | // ask the system to enable it. |
540 | if (addr.ai_family == AF_INET6 && !socket_->enableDualStack()) { |
541 | C10D_WARNING( |
542 | "The server socket does not support IPv4 communication on {}." , addr); |
543 | } |
544 | |
545 | if (::bind(socket_->handle(), addr.ai_addr, addr.ai_addrlen) != 0) { |
546 | recordError( |
547 | "The server socket has failed to bind to {} {}." , |
548 | addr, |
549 | getSocketError()); |
550 | |
551 | return false; |
552 | } |
553 | |
554 | // NOLINTNEXTLINE(bugprone-argument-comment) |
555 | if (::listen(socket_->handle(), /*backlog=*/2048) != 0) { |
556 | recordError( |
557 | "The server socket has failed to listen on {} {}." , |
558 | addr, |
559 | getSocketError()); |
560 | |
561 | return false; |
562 | } |
563 | |
564 | socket_->closeOnExec(); |
565 | |
566 | C10D_INFO("The server socket has started to listen on {}." , addr); |
567 | |
568 | return true; |
569 | } |
570 | |
571 | class SocketConnectOp { |
572 | using Clock = std::chrono::steady_clock; |
573 | using Duration = std::chrono::steady_clock::duration; |
574 | using TimePoint = std::chrono::time_point<std::chrono::steady_clock>; |
575 | |
576 | static const std::chrono::seconds delay_duration_; |
577 | |
578 | enum class ConnectResult { Success, Error, Retry }; |
579 | |
580 | public: |
581 | SocketConnectOp( |
582 | const std::string& host, |
583 | std::uint16_t port, |
584 | const SocketOptions& opts); |
585 | |
586 | std::unique_ptr<SocketImpl> run(); |
587 | |
588 | private: |
589 | bool tryConnect(int family); |
590 | |
591 | ConnectResult tryConnect(const ::addrinfo& addr); |
592 | |
593 | ConnectResult tryConnectCore(const ::addrinfo& addr); |
594 | |
595 | [[noreturn]] void throwTimeoutError() const; |
596 | |
597 | template <typename... Args> |
598 | void recordError(fmt::string_view format, Args&&... args) { |
599 | auto msg = fmt::vformat(format, fmt::make_format_args(args...)); |
600 | |
601 | C10D_WARNING(msg); |
602 | |
603 | errors_.emplace_back(std::move(msg)); |
604 | } |
605 | |
606 | const char* host_; |
607 | std::string port_; |
608 | const SocketOptions* opts_; |
609 | TimePoint deadline_{}; |
610 | std::vector<std::string> errors_{}; |
611 | std::unique_ptr<SocketImpl> socket_{}; |
612 | }; |
613 | |
614 | const std::chrono::seconds SocketConnectOp::delay_duration_{1}; |
615 | |
616 | SocketConnectOp::SocketConnectOp( |
617 | const std::string& host, |
618 | std::uint16_t port, |
619 | const SocketOptions& opts) |
620 | : host_{host.c_str()}, port_{fmt::to_string(port)}, opts_{&opts} {} |
621 | |
622 | std::unique_ptr<SocketImpl> SocketConnectOp::run() { |
623 | if (opts_->prefer_ipv6()) { |
624 | C10D_DEBUG( |
625 | "The client socket will attempt to connect to an IPv6 address of ({}, {})." , |
626 | host_, |
627 | port_); |
628 | |
629 | if (tryConnect(AF_INET6)) { |
630 | return std::move(socket_); |
631 | } |
632 | |
633 | C10D_DEBUG( |
634 | "The client socket will attempt to connect to an IPv4 address of ({}, {})." , |
635 | host_, |
636 | port_); |
637 | |
638 | if (tryConnect(AF_INET)) { |
639 | return std::move(socket_); |
640 | } |
641 | } else { |
642 | C10D_DEBUG( |
643 | "The client socket will attempt to connect to an IPv4 or IPv6 address of ({}, {})." , |
644 | host_, |
645 | port_); |
646 | |
647 | if (tryConnect(AF_UNSPEC)) { |
648 | return std::move(socket_); |
649 | } |
650 | } |
651 | |
652 | auto msg = fmt::format( |
653 | "The client socket has failed to connect to any network address of ({}, {})." , |
654 | host_, |
655 | port_); |
656 | |
657 | C10D_ERROR(msg); |
658 | |
659 | throw SocketError{fmt::format("{} {}" , msg, fmt::join(errors_, " " ))}; |
660 | } |
661 | |
// Repeatedly resolves (host_, port_) for `family` and tries each address until
// one connects. Returns true once one connects.
//
// A round is retried after `delay_duration_` when address resolution failed or
// any address asked for a retry. Returns false if a round ends with only hard
// errors. Throws TimeoutError once another round would not fit before the
// deadline.
bool SocketConnectOp::tryConnect(int family) {
  ::addrinfo hints{};
  hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  deadline_ = Clock::now() + opts_->connect_timeout();

  std::size_t attempt = 1;

  for (;;) {
    errors_.clear();

    bool should_retry = false;

    ::addrinfo* raw_list = nullptr;
    // patternlint-disable cpp-dns-deps
    const int gai_rc = ::getaddrinfo(host_, port_.c_str(), &hints, &raw_list);
    if (gai_rc != 0) {
      const char* gai_err = ::gai_strerror(gai_rc);

      const char* family_label = "";
      if (family == AF_INET) {
        family_label = "IPv4 ";
      } else if (family == AF_INET6) {
        family_label = "IPv6 ";
      }

      recordError(
          "The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).",
          family_label,
          host_,
          port_,
          gai_rc,
          gai_err);

      should_retry = true;
    } else {
      addrinfo_ptr owner{raw_list};

      for (::addrinfo* it = raw_list; it != nullptr; it = it->ai_next) {
        C10D_TRACE("The client socket is attempting to connect to {}.", *it);

        const ConnectResult outcome = tryConnect(*it);
        if (outcome == ConnectResult::Success) {
          return true;
        }
        if (outcome == ConnectResult::Retry) {
          should_retry = true;
        }
      }
    }

    if (!should_retry) {
      return false;
    }

    // Give up if another round would not fit before the deadline.
    if (Clock::now() >= deadline_ - delay_duration_) {
      throwTimeoutError();
    }

    // Keep the log quiet: report only every 30 attempts (roughly every 30
    // seconds given the one-second delay).
    if (attempt == 30) {
      C10D_INFO(
          "No socket on ({}, {}) is listening yet, will retry.",
          host_,
          port_);

      attempt = 0;
    }

    // Wait one second to avoid choking the server.
    delay(delay_duration_);

    ++attempt;
  }
}
736 | |
737 | SocketConnectOp::ConnectResult SocketConnectOp::tryConnect( |
738 | const ::addrinfo& addr) { |
739 | if (Clock::now() >= deadline_) { |
740 | throwTimeoutError(); |
741 | } |
742 | |
743 | SocketImpl::Handle hnd = |
744 | ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); |
745 | if (hnd == SocketImpl::invalid_socket) { |
746 | recordError( |
747 | "The client socket cannot be initialized to connect to {} {}." , |
748 | addr, |
749 | getSocketError()); |
750 | |
751 | return ConnectResult::Error; |
752 | } |
753 | |
754 | socket_ = std::make_unique<SocketImpl>(hnd); |
755 | |
756 | socket_->enableNonBlocking(); |
757 | |
758 | ConnectResult cr = tryConnectCore(addr); |
759 | if (cr == ConnectResult::Error) { |
760 | std::error_code err = getSocketError(); |
761 | if (err == std::errc::interrupted) { |
762 | throw std::system_error{err}; |
763 | } |
764 | |
765 | // Retry if the server is not yet listening or if its backlog is exhausted. |
766 | if (err == std::errc::connection_refused || |
767 | err == std::errc::connection_reset) { |
768 | C10D_TRACE( |
769 | "The server socket on {} is not yet listening {}, will retry." , |
770 | addr, |
771 | err); |
772 | |
773 | return ConnectResult::Retry; |
774 | } else { |
775 | recordError( |
776 | "The client socket has failed to connect to {} {}." , addr, err); |
777 | |
778 | return ConnectResult::Error; |
779 | } |
780 | } |
781 | |
782 | socket_->closeOnExec(); |
783 | |
784 | // TODO: Remove once we fully migrate to non-blocking mode. |
785 | socket_->disableNonBlocking(); |
786 | |
787 | C10D_INFO("The client socket has connected to {} on {}." , addr, *socket_); |
788 | |
789 | if (!socket_->enableNoDelay()) { |
790 | C10D_WARNING( |
791 | "The no-delay option cannot be enabled for the client socket on {}." , |
792 | *socket_); |
793 | } |
794 | |
795 | return ConnectResult::Success; |
796 | } |
797 | |
798 | SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore( |
799 | const ::addrinfo& addr) { |
800 | int r = ::connect(socket_->handle(), addr.ai_addr, addr.ai_addrlen); |
801 | if (r == 0) { |
802 | return ConnectResult::Success; |
803 | } |
804 | |
805 | std::error_code err = getSocketError(); |
806 | if (err == std::errc::already_connected) { |
807 | return ConnectResult::Success; |
808 | } |
809 | |
810 | if (err != std::errc::operation_in_progress && |
811 | err != std::errc::operation_would_block) { |
812 | return ConnectResult::Error; |
813 | } |
814 | |
815 | Duration remaining = deadline_ - Clock::now(); |
816 | if (remaining <= Duration::zero()) { |
817 | throwTimeoutError(); |
818 | } |
819 | |
820 | ::pollfd pfd{}; |
821 | pfd.fd = socket_->handle(); |
822 | pfd.events = POLLOUT; |
823 | |
824 | auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(remaining); |
825 | |
826 | r = pollFd(&pfd, 1, static_cast<int>(ms.count())); |
827 | if (r == 0) { |
828 | throwTimeoutError(); |
829 | } |
830 | if (r == -1) { |
831 | return ConnectResult::Error; |
832 | } |
833 | |
834 | int err_code = 0; |
835 | |
836 | ::socklen_t err_len = sizeof(int); |
837 | |
838 | r = getSocketOption( |
839 | socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len); |
840 | if (r != 0) { |
841 | return ConnectResult::Error; |
842 | } |
843 | |
844 | if (err_code != 0) { |
845 | setSocketError(err_code); |
846 | |
847 | return ConnectResult::Error; |
848 | } else { |
849 | return ConnectResult::Success; |
850 | } |
851 | } |
852 | |
853 | void SocketConnectOp::throwTimeoutError() const { |
854 | auto msg = fmt::format( |
855 | "The client socket has timed out after {} while trying to connect to ({}, {})." , |
856 | opts_->connect_timeout(), |
857 | host_, |
858 | port_); |
859 | |
860 | C10D_ERROR(msg); |
861 | |
862 | throw TimeoutError{msg}; |
863 | } |
864 | |
865 | } // namespace |
866 | |
867 | void Socket::initialize() { |
868 | #ifdef _WIN32 |
869 | static c10::once_flag init_flag{}; |
870 | |
871 | // All processes that call socket functions on Windows must first initialize |
872 | // the Winsock library. |
873 | c10::call_once(init_flag, []() { |
874 | WSADATA data{}; |
875 | if (::WSAStartup(MAKEWORD(2, 2), &data) != 0) { |
876 | throw SocketError{"The initialization of Winsock has failed." }; |
877 | } |
878 | }); |
879 | #endif |
880 | } |
881 | |
882 | Socket Socket::listen(std::uint16_t port, const SocketOptions& opts) { |
883 | SocketListenOp op{port, opts}; |
884 | |
885 | return Socket{op.run()}; |
886 | } |
887 | |
888 | Socket Socket::connect( |
889 | const std::string& host, |
890 | std::uint16_t port, |
891 | const SocketOptions& opts) { |
892 | SocketConnectOp op{host, port, opts}; |
893 | |
894 | return Socket{op.run()}; |
895 | } |
896 | |
897 | Socket::Socket(Socket&& other) noexcept = default; |
898 | |
899 | Socket& Socket::operator=(Socket&& other) noexcept = default; |
900 | |
901 | Socket::~Socket() = default; |
902 | |
903 | Socket Socket::accept() const { |
904 | if (impl_) { |
905 | return Socket{impl_->accept()}; |
906 | } |
907 | |
908 | throw SocketError{"The socket is not initialized." }; |
909 | } |
910 | |
911 | int Socket::handle() const noexcept { |
912 | if (impl_) { |
913 | return impl_->handle(); |
914 | } |
915 | return SocketImpl::invalid_socket; |
916 | } |
917 | |
918 | std::uint16_t Socket::port() const { |
919 | if (impl_) { |
920 | return impl_->getPort(); |
921 | } |
922 | return 0; |
923 | } |
924 | |
925 | Socket::Socket(std::unique_ptr<SocketImpl>&& impl) noexcept |
926 | : impl_{std::move(impl)} {} |
927 | |
928 | } // namespace detail |
929 | |
// Out-of-line defaulted destructor for SocketError, which is declared
// elsewhere (presumably in exception.h — confirm).
SocketError::~SocketError() = default;
931 | |
932 | } // namespace c10d |
933 | |