1// Copyright (c) Meta Platforms, Inc. and its affiliates.
2// All rights reserved.
3//
4// This source code is licensed under the BSD-style license found in the
5// LICENSE file in the root directory of this source tree.
6
7#include <torch/csrc/distributed/c10d/socket.h>
8
9#include <cstring>
10#include <system_error>
11#include <thread>
12#include <utility>
13#include <vector>
14
15#ifdef _WIN32
16#include <mutex>
17
18#include <winsock2.h>
19#include <ws2tcpip.h>
20#else
21#include <fcntl.h>
22#include <netdb.h>
23#include <netinet/tcp.h>
24#include <poll.h>
25#include <sys/socket.h>
26#include <sys/types.h>
27#include <unistd.h>
28#endif
29
30#include <fmt/chrono.h>
31#include <fmt/format.h>
32
33#include <torch/csrc/distributed/c10d/error.h>
34#include <torch/csrc/distributed/c10d/exception.h>
35#include <torch/csrc/distributed/c10d/logging.h>
36
37#include <c10/util/CallOnce.h>
38
39namespace c10d {
40namespace detail {
41namespace {
42#ifdef _WIN32
43
44// Since Winsock uses the name `WSAPoll` instead of `poll`, we alias it here
45// to avoid #ifdefs in the source code.
46const auto pollFd = ::WSAPoll;
47
48// Winsock's `getsockopt()` and `setsockopt()` functions expect option values to
49// be passed as `char*` instead of `void*`. We wrap them here to avoid redundant
50// casts in the source code.
51int getSocketOption(
52 SOCKET s,
53 int level,
54 int optname,
55 void* optval,
56 int* optlen) {
57 return ::getsockopt(s, level, optname, static_cast<char*>(optval), optlen);
58}
59
60int setSocketOption(
61 SOCKET s,
62 int level,
63 int optname,
64 const void* optval,
65 int optlen) {
66 return ::setsockopt(
67 s, level, optname, static_cast<const char*>(optval), optlen);
68}
69
70// Winsock has its own error codes which differ from Berkeley's. Fortunately the
71// C++ Standard Library on Windows can map them to standard error codes.
72inline std::error_code getSocketError() noexcept {
73 return std::error_code{::WSAGetLastError(), std::system_category()};
74}
75
76inline void setSocketError(int val) noexcept {
77 ::WSASetLastError(val);
78}
79
80#else
81
// On this branch the system functions are used directly; the aliases only
// give them the platform-neutral names used by the rest of this file.
const auto pollFd = &::poll;

const auto getSocketOption = &::getsockopt;
const auto setSocketOption = &::setsockopt;
86
87inline std::error_code getSocketError() noexcept {
88 return lastError();
89}
90
// Overwrites `errno` with `val`.
inline void setSocketError(int val) noexcept {
  int& last_error = errno;

  last_error = val;
}
94
95#endif
96
97// Suspends the current thread for the specified duration.
98void delay(std::chrono::seconds d) {
99#ifdef _WIN32
100 std::this_thread::sleep_for(d);
101#else
102 ::timespec req{};
103 req.tv_sec = d.count();
104
105 // The C++ Standard does not specify whether `sleep_for()` should be signal-
106 // aware; therefore, we use the `nanosleep()` syscall.
107 if (::nanosleep(&req, nullptr) != 0) {
108 std::error_code err = getSocketError();
109 // We don't care about error conditions other than EINTR since a failure
110 // here is not critical.
111 if (err == std::errc::interrupted) {
112 throw std::system_error{err};
113 }
114 }
115#endif
116}
117
118class SocketListenOp;
119class SocketConnectOp;
120} // namespace
121
// Thin owner of a native socket handle.
//
// The handle given to the constructor is stored as-is and released by the
// out-of-line destructor. Instances can be neither copied nor moved, so they
// are passed around through `std::unique_ptr`.
class SocketImpl {
  friend class SocketListenOp;
  friend class SocketConnectOp;

 public:
  // Native handle type and the value that denotes "no socket".
#ifdef _WIN32
  using Handle = SOCKET;

  static constexpr Handle invalid_socket = INVALID_SOCKET;
#else
  using Handle = int;

  static constexpr Handle invalid_socket = -1;
#endif

  explicit SocketImpl(Handle hnd) noexcept : hnd_{hnd} {}

  SocketImpl(const SocketImpl&) = delete;

  SocketImpl& operator=(const SocketImpl&) = delete;

  SocketImpl(SocketImpl&&) noexcept = delete;

  SocketImpl& operator=(SocketImpl&&) noexcept = delete;

  ~SocketImpl();

  // Accepts a pending connection and returns the socket that represents it.
  std::unique_ptr<SocketImpl> accept() const;

  // Requests that the handle is not inherited across `exec`.
  void closeOnExec() noexcept;

  // Switch the socket between non-blocking and blocking mode; both throw on
  // failure.
  void enableNonBlocking();

  void disableNonBlocking();

  // Option setters; each returns `true` if the option could be applied.
  bool enableNoDelay() noexcept;

  bool enableDualStack() noexcept;

#ifdef _WIN32
  bool enableExclusiveAddressUse() noexcept;
#else
  bool enableAddressReuse() noexcept;
#endif

  // Returns the local port the socket is bound to, in host byte order.
  std::uint16_t getPort() const;

  Handle handle() const noexcept {
    return hnd_;
  }

 private:
  // Applies a boolean socket option; returns `true` on success.
  bool setSocketFlag(int level, int optname, bool value) noexcept;

  Handle hnd_;
};
182} // namespace detail
183} // namespace c10d
184
185//
186// libfmt formatters for `addrinfo` and `Socket`
187//
188namespace fmt {
189
190template <>
191struct formatter<::addrinfo> {
192 constexpr decltype(auto) parse(format_parse_context& ctx) {
193 return ctx.begin();
194 }
195
196 template <typename FormatContext>
197 decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) {
198 char host[NI_MAXHOST], port[NI_MAXSERV]; // NOLINT
199
200 int r = ::getnameinfo(
201 addr.ai_addr,
202 addr.ai_addrlen,
203 host,
204 NI_MAXHOST,
205 port,
206 NI_MAXSERV,
207 NI_NUMERICSERV);
208 if (r != 0) {
209 return format_to(ctx.out(), "?UNKNOWN?");
210 }
211
212 if (addr.ai_addr->sa_family == AF_INET) {
213 return format_to(ctx.out(), "{}:{}", host, port);
214 } else {
215 return format_to(ctx.out(), "[{}]:{}", host, port);
216 }
217 }
218};
219
220template <>
221struct formatter<c10d::detail::SocketImpl> {
222 constexpr decltype(auto) parse(format_parse_context& ctx) {
223 return ctx.begin();
224 }
225
226 template <typename FormatContext>
227 decltype(auto) format(
228 const c10d::detail::SocketImpl& socket,
229 FormatContext& ctx) {
230 ::sockaddr_storage addr_s{};
231
232 auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);
233
234 ::socklen_t addr_len = sizeof(addr_s);
235
236 if (::getsockname(socket.handle(), addr_ptr, &addr_len) != 0) {
237 return format_to(ctx.out(), "?UNKNOWN?");
238 }
239
240 ::addrinfo addr{};
241 addr.ai_addr = addr_ptr;
242 addr.ai_addrlen = addr_len;
243
244 return format_to(ctx.out(), "{}", addr);
245 }
246};
247
248} // namespace fmt
249
250namespace c10d {
251namespace detail {
252
253SocketImpl::~SocketImpl() {
254#ifdef _WIN32
255 ::closesocket(hnd_);
256#else
257 ::close(hnd_);
258#endif
259}
260
261std::unique_ptr<SocketImpl> SocketImpl::accept() const {
262 ::sockaddr_storage addr_s{};
263
264 auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s);
265
266 ::socklen_t addr_len = sizeof(addr_s);
267
268 Handle hnd = ::accept(hnd_, addr_ptr, &addr_len);
269 if (hnd == invalid_socket) {
270 std::error_code err = getSocketError();
271 if (err == std::errc::interrupted) {
272 throw std::system_error{err};
273 }
274
275 std::string msg{};
276 if (err == std::errc::invalid_argument) {
277 msg = fmt::format(
278 "The server socket on {} is not listening for connections.", *this);
279 } else {
280 msg = fmt::format(
281 "The server socket on {} has failed to accept a connection {}.",
282 *this,
283 err);
284 }
285
286 C10D_ERROR(msg);
287
288 throw SocketError{msg};
289 }
290
291 ::addrinfo addr{};
292 addr.ai_addr = addr_ptr;
293 addr.ai_addrlen = addr_len;
294
295 C10D_DEBUG(
296 "The server socket on {} has accepted a connection from {}.",
297 *this,
298 addr);
299
300 auto impl = std::make_unique<SocketImpl>(hnd);
301
302 // Make sure that we do not "leak" our file descriptors to child processes.
303 impl->closeOnExec();
304
305 if (!impl->enableNoDelay()) {
306 C10D_WARNING(
307 "The no-delay option cannot be enabled for the client socket on {}.",
308 addr);
309 }
310
311 return impl;
312}
313
314void SocketImpl::closeOnExec() noexcept {
315#ifndef _WIN32
316 ::fcntl(hnd_, F_SETFD, FD_CLOEXEC);
317#endif
318}
319
320void SocketImpl::enableNonBlocking() {
321#ifdef _WIN32
322 unsigned long value = 1;
323 if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) {
324 return;
325 }
326#else
327 int flg = ::fcntl(hnd_, F_GETFL);
328 if (flg != -1) {
329 if (::fcntl(hnd_, F_SETFL, flg | O_NONBLOCK) == 0) {
330 return;
331 }
332 }
333#endif
334 throw SocketError{"The socket cannot be switched to non-blocking mode."};
335}
336
337// TODO: Remove once we migrate everything to non-blocking mode.
338void SocketImpl::disableNonBlocking() {
339#ifdef _WIN32
340 unsigned long value = 0;
341 if (::ioctlsocket(hnd_, FIONBIO, &value) == 0) {
342 return;
343 }
344#else
345 int flg = ::fcntl(hnd_, F_GETFL);
346 if (flg != -1) {
347 if (::fcntl(hnd_, F_SETFL, flg & ~O_NONBLOCK) == 0) {
348 return;
349 }
350 }
351#endif
352 throw SocketError{"The socket cannot be switched to blocking mode."};
353}
354
355bool SocketImpl::enableNoDelay() noexcept {
356 return setSocketFlag(IPPROTO_TCP, TCP_NODELAY, true);
357}
358
359bool SocketImpl::enableDualStack() noexcept {
360 return setSocketFlag(IPPROTO_IPV6, IPV6_V6ONLY, false);
361}
362
363#ifndef _WIN32
364bool SocketImpl::enableAddressReuse() noexcept {
365 return setSocketFlag(SOL_SOCKET, SO_REUSEADDR, true);
366}
367#endif
368
369#ifdef _WIN32
370bool SocketImpl::enableExclusiveAddressUse() noexcept {
371 return setSocketFlag(SOL_SOCKET, SO_EXCLUSIVEADDRUSE, true);
372}
373#endif
374
375std::uint16_t SocketImpl::getPort() const {
376 ::sockaddr_storage addr_s{};
377
378 ::socklen_t addr_len = sizeof(addr_s);
379
380 if (::getsockname(hnd_, reinterpret_cast<::sockaddr*>(&addr_s), &addr_len) !=
381 0) {
382 throw SocketError{"The port number of the socket cannot be retrieved."};
383 }
384
385 if (addr_s.ss_family == AF_INET) {
386 return ntohs(reinterpret_cast<::sockaddr_in*>(&addr_s)->sin_port);
387 } else {
388 return ntohs(reinterpret_cast<::sockaddr_in6*>(&addr_s)->sin6_port);
389 }
390}
391
392bool SocketImpl::setSocketFlag(int level, int optname, bool value) noexcept {
393#ifdef _WIN32
394 auto buf = value ? TRUE : FALSE;
395#else
396 auto buf = value ? 1 : 0;
397#endif
398 return setSocketOption(hnd_, level, optname, &buf, sizeof(buf)) == 0;
399}
400
401namespace {
402
// Deleter that releases an address list with `freeaddrinfo()`.
struct addrinfo_delete {
  void operator()(::addrinfo* head) const noexcept {
    ::freeaddrinfo(head);
  }
};

// Owning handle for an address list; frees it through `addrinfo_delete`.
using addrinfo_ptr = std::unique_ptr<::addrinfo, addrinfo_delete>;
410
411class SocketListenOp {
412 public:
413 SocketListenOp(std::uint16_t port, const SocketOptions& opts);
414
415 std::unique_ptr<SocketImpl> run();
416
417 private:
418 bool tryListen(int family);
419
420 bool tryListen(const ::addrinfo& addr);
421
422 template <typename... Args>
423 void recordError(fmt::string_view format, Args&&... args) {
424 auto msg = fmt::vformat(format, fmt::make_format_args(args...));
425
426 C10D_WARNING(msg);
427
428 errors_.emplace_back(std::move(msg));
429 }
430
431 std::string port_;
432 const SocketOptions* opts_;
433 std::vector<std::string> errors_{};
434 std::unique_ptr<SocketImpl> socket_{};
435};
436
437SocketListenOp::SocketListenOp(std::uint16_t port, const SocketOptions& opts)
438 : port_{fmt::to_string(port)}, opts_{&opts} {}
439
440std::unique_ptr<SocketImpl> SocketListenOp::run() {
441 if (opts_->prefer_ipv6()) {
442 C10D_DEBUG("The server socket will attempt to listen on an IPv6 address.");
443 if (tryListen(AF_INET6)) {
444 return std::move(socket_);
445 }
446
447 C10D_DEBUG("The server socket will attempt to listen on an IPv4 address.");
448 if (tryListen(AF_INET)) {
449 return std::move(socket_);
450 }
451 } else {
452 C10D_DEBUG(
453 "The server socket will attempt to listen on an IPv4 or IPv6 address.");
454 if (tryListen(AF_UNSPEC)) {
455 return std::move(socket_);
456 }
457 }
458
459 constexpr auto* msg =
460 "The server socket has failed to listen on any local network address.";
461
462 C10D_ERROR(msg);
463
464 throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
465}
466
// Resolves the local addresses of the given family for `port_` and tries to
// listen on each of them in turn.
//
// Returns `true` as soon as one address works; `false` if the lookup fails or
// no address can be listened on.
bool SocketListenOp::tryListen(int family) {
  ::addrinfo hints{};
  hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  ::addrinfo* candidates = nullptr;

  int rc = ::getaddrinfo(nullptr, port_.c_str(), &hints, &candidates);
  if (rc != 0) {
    const char* gai_err = ::gai_strerror(rc);

    // Only name the family in the message if a specific one was requested.
    const char* family_label = "";
    if (family == AF_INET) {
      family_label = "IPv4 ";
    } else if (family == AF_INET6) {
      family_label = "IPv6 ";
    }

    recordError(
        "The local {}network addresses cannot be retrieved (gai error: {} - {}).",
        family_label,
        rc,
        gai_err);

    return false;
  }

  // Frees the address list when the function returns.
  addrinfo_ptr owner{candidates};

  ::addrinfo* current = candidates;
  while (current != nullptr) {
    C10D_DEBUG("The server socket is attempting to listen on {}.", *current);
    if (tryListen(*current)) {
      return true;
    }

    current = current->ai_next;
  }

  return false;
}
500
501bool SocketListenOp::tryListen(const ::addrinfo& addr) {
502 SocketImpl::Handle hnd =
503 ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
504 if (hnd == SocketImpl::invalid_socket) {
505 recordError(
506 "The server socket cannot be initialized on {} {}.",
507 addr,
508 getSocketError());
509
510 return false;
511 }
512
513 socket_ = std::make_unique<SocketImpl>(hnd);
514
515#ifndef _WIN32
516 if (!socket_->enableAddressReuse()) {
517 C10D_WARNING(
518 "The address reuse option cannot be enabled for the server socket on {}.",
519 addr);
520 }
521#endif
522
523#ifdef _WIN32
524 // The SO_REUSEADDR flag has a significantly different behavior on Windows
525 // compared to Unix-like systems. It allows two or more processes to share
526 // the same port simultaneously, which is totally unsafe.
527 //
528 // Here we follow the recommendation of Microsoft and use the non-standard
529 // SO_EXCLUSIVEADDRUSE flag instead.
530 if (!socket_->enableExclusiveAddressUse()) {
531 C10D_WARNING(
532 "The exclusive address use option cannot be enabled for the server socket on {}.",
533 addr);
534 }
535#endif
536
537 // Not all operating systems support dual-stack sockets by default. Since we
538 // wish to use our IPv6 socket for IPv4 communication as well, we explicitly
539 // ask the system to enable it.
540 if (addr.ai_family == AF_INET6 && !socket_->enableDualStack()) {
541 C10D_WARNING(
542 "The server socket does not support IPv4 communication on {}.", addr);
543 }
544
545 if (::bind(socket_->handle(), addr.ai_addr, addr.ai_addrlen) != 0) {
546 recordError(
547 "The server socket has failed to bind to {} {}.",
548 addr,
549 getSocketError());
550
551 return false;
552 }
553
554 // NOLINTNEXTLINE(bugprone-argument-comment)
555 if (::listen(socket_->handle(), /*backlog=*/2048) != 0) {
556 recordError(
557 "The server socket has failed to listen on {} {}.",
558 addr,
559 getSocketError());
560
561 return false;
562 }
563
564 socket_->closeOnExec();
565
566 C10D_INFO("The server socket has started to listen on {}.", addr);
567
568 return true;
569}
570
571class SocketConnectOp {
572 using Clock = std::chrono::steady_clock;
573 using Duration = std::chrono::steady_clock::duration;
574 using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
575
576 static const std::chrono::seconds delay_duration_;
577
578 enum class ConnectResult { Success, Error, Retry };
579
580 public:
581 SocketConnectOp(
582 const std::string& host,
583 std::uint16_t port,
584 const SocketOptions& opts);
585
586 std::unique_ptr<SocketImpl> run();
587
588 private:
589 bool tryConnect(int family);
590
591 ConnectResult tryConnect(const ::addrinfo& addr);
592
593 ConnectResult tryConnectCore(const ::addrinfo& addr);
594
595 [[noreturn]] void throwTimeoutError() const;
596
597 template <typename... Args>
598 void recordError(fmt::string_view format, Args&&... args) {
599 auto msg = fmt::vformat(format, fmt::make_format_args(args...));
600
601 C10D_WARNING(msg);
602
603 errors_.emplace_back(std::move(msg));
604 }
605
606 const char* host_;
607 std::string port_;
608 const SocketOptions* opts_;
609 TimePoint deadline_{};
610 std::vector<std::string> errors_{};
611 std::unique_ptr<SocketImpl> socket_{};
612};
613
614const std::chrono::seconds SocketConnectOp::delay_duration_{1};
615
616SocketConnectOp::SocketConnectOp(
617 const std::string& host,
618 std::uint16_t port,
619 const SocketOptions& opts)
620 : host_{host.c_str()}, port_{fmt::to_string(port)}, opts_{&opts} {}
621
622std::unique_ptr<SocketImpl> SocketConnectOp::run() {
623 if (opts_->prefer_ipv6()) {
624 C10D_DEBUG(
625 "The client socket will attempt to connect to an IPv6 address of ({}, {}).",
626 host_,
627 port_);
628
629 if (tryConnect(AF_INET6)) {
630 return std::move(socket_);
631 }
632
633 C10D_DEBUG(
634 "The client socket will attempt to connect to an IPv4 address of ({}, {}).",
635 host_,
636 port_);
637
638 if (tryConnect(AF_INET)) {
639 return std::move(socket_);
640 }
641 } else {
642 C10D_DEBUG(
643 "The client socket will attempt to connect to an IPv4 or IPv6 address of ({}, {}).",
644 host_,
645 port_);
646
647 if (tryConnect(AF_UNSPEC)) {
648 return std::move(socket_);
649 }
650 }
651
652 auto msg = fmt::format(
653 "The client socket has failed to connect to any network address of ({}, {}).",
654 host_,
655 port_);
656
657 C10D_ERROR(msg);
658
659 throw SocketError{fmt::format("{} {}", msg, fmt::join(errors_, " "))};
660}
661
// Resolves the addresses of (host_, port_) for the given family and tries to
// connect to each of them, repeating the whole round once per
// `delay_duration_` while an attempt asks for a retry.
//
// Returns `true` once a connection has been established and `false` if a
// round ends without any retry request. Throws a timeout error once another
// round no longer fits in before the deadline.
bool SocketConnectOp::tryConnect(int family) {
  ::addrinfo hints{};
  hints.ai_flags = AI_V4MAPPED | AI_ALL | AI_NUMERICSERV;
  hints.ai_family = family;
  hints.ai_socktype = SOCK_STREAM;

  // The deadline applies to this call as a whole.
  deadline_ = Clock::now() + opts_->connect_timeout();

  std::size_t rounds_since_log = 1;

  for (;;) {
    bool should_retry = false;

    // Only the failures of the latest round are kept.
    errors_.clear();

    ::addrinfo* candidates = nullptr;
    // patternlint-disable cpp-dns-deps
    int rc = ::getaddrinfo(host_, port_.c_str(), &hints, &candidates);
    if (rc == 0) {
      // Frees the address list at the end of this round.
      addrinfo_ptr owner{candidates};

      ::addrinfo* current = candidates;
      while (current != nullptr) {
        C10D_TRACE(
            "The client socket is attempting to connect to {}.", *current);

        ConnectResult outcome = tryConnect(*current);
        if (outcome == ConnectResult::Success) {
          return true;
        }

        if (outcome == ConnectResult::Retry) {
          should_retry = true;
        }

        current = current->ai_next;
      }
    } else {
      const char* gai_err = ::gai_strerror(rc);

      // Only name the family in the message if a specific one was requested.
      const char* family_label = "";
      if (family == AF_INET) {
        family_label = "IPv4 ";
      } else if (family == AF_INET6) {
        family_label = "IPv6 ";
      }

      recordError(
          "The {}network addresses of ({}, {}) cannot be retrieved (gai error: {} - {}).",
          family_label,
          host_,
          port_,
          rc,
          gai_err);

      // A failed lookup is retried as well.
      should_retry = true;
    }

    if (!should_retry) {
      return false;
    }

    // Give up if the pause before the next round would run past the deadline.
    if (Clock::now() >= deadline_ - delay_duration_) {
      throwTimeoutError();
    }

    // Prevent our log output to be too noisy, warn only every 30 seconds.
    if (rounds_since_log == 30) {
      C10D_INFO(
          "No socket on ({}, {}) is listening yet, will retry.",
          host_,
          port_);

      rounds_since_log = 0;
    }

    // Wait one second to avoid choking the server.
    delay(delay_duration_);

    rounds_since_log++;
  }
}
736
737SocketConnectOp::ConnectResult SocketConnectOp::tryConnect(
738 const ::addrinfo& addr) {
739 if (Clock::now() >= deadline_) {
740 throwTimeoutError();
741 }
742
743 SocketImpl::Handle hnd =
744 ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
745 if (hnd == SocketImpl::invalid_socket) {
746 recordError(
747 "The client socket cannot be initialized to connect to {} {}.",
748 addr,
749 getSocketError());
750
751 return ConnectResult::Error;
752 }
753
754 socket_ = std::make_unique<SocketImpl>(hnd);
755
756 socket_->enableNonBlocking();
757
758 ConnectResult cr = tryConnectCore(addr);
759 if (cr == ConnectResult::Error) {
760 std::error_code err = getSocketError();
761 if (err == std::errc::interrupted) {
762 throw std::system_error{err};
763 }
764
765 // Retry if the server is not yet listening or if its backlog is exhausted.
766 if (err == std::errc::connection_refused ||
767 err == std::errc::connection_reset) {
768 C10D_TRACE(
769 "The server socket on {} is not yet listening {}, will retry.",
770 addr,
771 err);
772
773 return ConnectResult::Retry;
774 } else {
775 recordError(
776 "The client socket has failed to connect to {} {}.", addr, err);
777
778 return ConnectResult::Error;
779 }
780 }
781
782 socket_->closeOnExec();
783
784 // TODO: Remove once we fully migrate to non-blocking mode.
785 socket_->disableNonBlocking();
786
787 C10D_INFO("The client socket has connected to {} on {}.", addr, *socket_);
788
789 if (!socket_->enableNoDelay()) {
790 C10D_WARNING(
791 "The no-delay option cannot be enabled for the client socket on {}.",
792 *socket_);
793 }
794
795 return ConnectResult::Success;
796}
797
798SocketConnectOp::ConnectResult SocketConnectOp::tryConnectCore(
799 const ::addrinfo& addr) {
800 int r = ::connect(socket_->handle(), addr.ai_addr, addr.ai_addrlen);
801 if (r == 0) {
802 return ConnectResult::Success;
803 }
804
805 std::error_code err = getSocketError();
806 if (err == std::errc::already_connected) {
807 return ConnectResult::Success;
808 }
809
810 if (err != std::errc::operation_in_progress &&
811 err != std::errc::operation_would_block) {
812 return ConnectResult::Error;
813 }
814
815 Duration remaining = deadline_ - Clock::now();
816 if (remaining <= Duration::zero()) {
817 throwTimeoutError();
818 }
819
820 ::pollfd pfd{};
821 pfd.fd = socket_->handle();
822 pfd.events = POLLOUT;
823
824 auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(remaining);
825
826 r = pollFd(&pfd, 1, static_cast<int>(ms.count()));
827 if (r == 0) {
828 throwTimeoutError();
829 }
830 if (r == -1) {
831 return ConnectResult::Error;
832 }
833
834 int err_code = 0;
835
836 ::socklen_t err_len = sizeof(int);
837
838 r = getSocketOption(
839 socket_->handle(), SOL_SOCKET, SO_ERROR, &err_code, &err_len);
840 if (r != 0) {
841 return ConnectResult::Error;
842 }
843
844 if (err_code != 0) {
845 setSocketError(err_code);
846
847 return ConnectResult::Error;
848 } else {
849 return ConnectResult::Success;
850 }
851}
852
853void SocketConnectOp::throwTimeoutError() const {
854 auto msg = fmt::format(
855 "The client socket has timed out after {} while trying to connect to ({}, {}).",
856 opts_->connect_timeout(),
857 host_,
858 port_);
859
860 C10D_ERROR(msg);
861
862 throw TimeoutError{msg};
863}
864
865} // namespace
866
867void Socket::initialize() {
868#ifdef _WIN32
869 static c10::once_flag init_flag{};
870
871 // All processes that call socket functions on Windows must first initialize
872 // the Winsock library.
873 c10::call_once(init_flag, []() {
874 WSADATA data{};
875 if (::WSAStartup(MAKEWORD(2, 2), &data) != 0) {
876 throw SocketError{"The initialization of Winsock has failed."};
877 }
878 });
879#endif
880}
881
882Socket Socket::listen(std::uint16_t port, const SocketOptions& opts) {
883 SocketListenOp op{port, opts};
884
885 return Socket{op.run()};
886}
887
888Socket Socket::connect(
889 const std::string& host,
890 std::uint16_t port,
891 const SocketOptions& opts) {
892 SocketConnectOp op{host, port, opts};
893
894 return Socket{op.run()};
895}
896
897Socket::Socket(Socket&& other) noexcept = default;
898
899Socket& Socket::operator=(Socket&& other) noexcept = default;
900
901Socket::~Socket() = default;
902
903Socket Socket::accept() const {
904 if (impl_) {
905 return Socket{impl_->accept()};
906 }
907
908 throw SocketError{"The socket is not initialized."};
909}
910
911int Socket::handle() const noexcept {
912 if (impl_) {
913 return impl_->handle();
914 }
915 return SocketImpl::invalid_socket;
916}
917
918std::uint16_t Socket::port() const {
919 if (impl_) {
920 return impl_->getPort();
921 }
922 return 0;
923}
924
925Socket::Socket(std::unique_ptr<SocketImpl>&& impl) noexcept
926 : impl_{std::move(impl)} {}
927
928} // namespace detail
929
930SocketError::~SocketError() = default;
931
932} // namespace c10d
933