1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file socket.h
22 * \brief this file aims to provide a wrapper of sockets
23 * \author Tianqi Chen
24 */
25#ifndef TVM_SUPPORT_SOCKET_H_
26#define TVM_SUPPORT_SOCKET_H_
27
28#if defined(_WIN32)
29
30#ifndef NOMINMAX
31#define NOMINMAX
32#endif
33
34#include <winsock2.h>
35#include <ws2tcpip.h>
36
37#ifdef _MSC_VER
38#pragma comment(lib, "Ws2_32.lib")
39#endif
40#else
41#include <arpa/inet.h>
42#include <errno.h>
43#include <fcntl.h>
44#include <netdb.h>
45#include <netinet/in.h>
46#include <sys/ioctl.h>
47#include <sys/select.h>
48#include <sys/socket.h>
49#include <unistd.h>
50#endif
51#include <tvm/runtime/logging.h>
52#include <tvm/runtime/registry.h>
53
54#include <cstring>
55#include <string>
56#include <unordered_map>
57#include <vector>
58
59#include "../support/ssize.h"
60#include "../support/utils.h"
61
62#if defined(_WIN32)
63static inline int poll(struct pollfd* pfd, int nfds, int timeout) {
64 return WSAPoll(pfd, nfds, timeout);
65}
66#else
67#include <sys/poll.h>
68#endif // defined(_WIN32)
69
70namespace tvm {
71namespace support {
72
73/*!
74 * \brief Get current host name.
75 * \return The hostname.
76 */
77inline std::string GetHostName() {
78 std::string buf;
79 buf.resize(256);
80 ICHECK_NE(gethostname(&buf[0], 256), -1);
81 return std::string(buf.c_str());
82}
83
/*!
 * \brief ValidateIP validates an ip address.
 * \param ip The ip address in string format: "localhost", an IPv4 address
 *        (x.x.x.x) or an IPv6 address.
 * \return true if ip is "localhost" or can be parsed as an IPv4 or IPv6
 *         address, false otherwise.
 */
inline bool ValidateIP(std::string ip) {
  if (ip == "localhost") {
    return true;
  }
  // inet_pton returns 1 on success, 0 when the text is not a valid address and
  // -1 on error. Only 1 means the string was parsed, so compare explicitly
  // instead of converting the result to bool (which would accept -1).
  in_addr ipv4_addr;
  if (inet_pton(AF_INET, ip.c_str(), &ipv4_addr) == 1) {
    return true;
  }
  in6_addr ipv6_addr;
  return inet_pton(AF_INET6, ip.c_str(), &ipv6_addr) == 1;
}
99
100/*!
101 * \brief Common data structure for network address.
102 */
103struct SockAddr {
104 sockaddr_storage addr;
105 SockAddr() {}
106 /*!
107 * \brief construct address by url and port
108 * \param url The url of the address
109 * \param port The port of the address.
110 */
111 SockAddr(const char* url, int port) { this->Set(url, port); }
112
113 /*!
114 * \brief SockAddr Get the socket address from tracker.
115 * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090)
116 * \return SockAddr parsed from url.
117 */
118 explicit SockAddr(const std::string& url) {
119 size_t sep = url.find(",");
120 std::string host = url.substr(2, sep - 3);
121 std::string port = url.substr(sep + 1, url.length() - 1);
122 ICHECK(ValidateIP(host)) << "Url address is not valid " << url;
123 if (host == "localhost") {
124 host = "127.0.0.1";
125 }
126 this->Set(host.c_str(), std::stoi(port));
127 }
128
129 /*!
130 * \brief set the address
131 * \param host the url of the address
132 * \param port the port of address
133 */
134 void Set(const char* host, int port) {
135 addrinfo hints;
136 memset(&hints, 0, sizeof(hints));
137 hints.ai_family = PF_UNSPEC;
138 hints.ai_flags = AI_PASSIVE;
139 hints.ai_socktype = SOCK_STREAM;
140 addrinfo* res = nullptr;
141 int sig = getaddrinfo(host, nullptr, &hints, &res);
142 ICHECK(sig == 0 && res != nullptr) << "cannot obtain address of " << host;
143 switch (res->ai_family) {
144 case AF_INET: {
145 sockaddr_in* addr4 = reinterpret_cast<sockaddr_in*>(&addr);
146 memcpy(addr4, res->ai_addr, res->ai_addrlen);
147 addr4->sin_port = htons(port);
148 addr4->sin_family = AF_INET;
149 } break;
150 case AF_INET6: {
151 sockaddr_in6* addr6 = reinterpret_cast<sockaddr_in6*>(&addr);
152 memcpy(addr6, res->ai_addr, res->ai_addrlen);
153 addr6->sin6_port = htons(port);
154 addr6->sin6_family = AF_INET6;
155 } break;
156 default:
157 ICHECK(false) << "cannot decode address";
158 }
159 freeaddrinfo(res);
160 }
161 /*! \brief return port of the address */
162 int port() const {
163 return ntohs((addr.ss_family == AF_INET6)
164 ? reinterpret_cast<const sockaddr_in6*>(&addr)->sin6_port
165 : reinterpret_cast<const sockaddr_in*>(&addr)->sin_port);
166 }
167 /*! \brief return the ip address family */
168 int ss_family() const { return addr.ss_family; }
169 /*! \return a string representation of the address */
170 std::string AsString() const {
171 std::string buf;
172 buf.resize(256);
173
174 const void* sinx_addr = nullptr;
175 if (addr.ss_family == AF_INET6) {
176 const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6*>(&addr)->sin6_addr;
177 sinx_addr = reinterpret_cast<const void*>(&addr6);
178 } else if (addr.ss_family == AF_INET) {
179 const in_addr& addr4 = reinterpret_cast<const sockaddr_in*>(&addr)->sin_addr;
180 sinx_addr = reinterpret_cast<const void*>(&addr4);
181 } else {
182 ICHECK(false) << "illegal address";
183 }
184
185#ifdef _WIN32
186 const char* s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*)
187 &buf[0], buf.length());
188#else
189 const char* s =
190 inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast<socklen_t>(buf.length()));
191#endif
192 ICHECK(s != nullptr) << "cannot decode address";
193 std::ostringstream os;
194 os << s << ":" << port();
195 return os.str();
196 }
197};
198/*!
199 * \brief base class containing common operations of TCP and UDP sockets
200 */
201class Socket {
202 public:
203#if defined(_WIN32)
204 using sock_size_t = int;
205 using SockType = SOCKET;
206#else
207 using SockType = int;
208 using sock_size_t = size_t;
209 static constexpr int INVALID_SOCKET = -1;
210#endif
211 /*! \brief the file descriptor of socket */
212 SockType sockfd;
213 /*!
214 * \brief set this socket to use non-blocking mode
215 * \param non_block whether set it to be non-block, if it is false
216 * it will set it back to block mode
217 */
218 void SetNonBlock(bool non_block) {
219#ifdef _WIN32
220 u_long mode = non_block ? 1 : 0;
221 if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) {
222 Socket::Error("SetNonBlock");
223 }
224#else
225 int flag = fcntl(sockfd, F_GETFL, 0);
226 if (flag == -1) {
227 Socket::Error("SetNonBlock-1");
228 }
229 if (non_block) {
230 flag |= O_NONBLOCK;
231 } else {
232 flag &= ~O_NONBLOCK;
233 }
234 if (fcntl(sockfd, F_SETFL, flag) == -1) {
235 Socket::Error("SetNonBlock-2");
236 }
237#endif
238 }
239 /*!
240 * \brief bind the socket to an address
241 * \param addr The address to be binded
242 */
243 void Bind(const SockAddr& addr) {
244 if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
245 (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) ==
246 -1) {
247 Socket::Error("Bind");
248 }
249 }
250 /*!
251 * \brief try bind the socket to host, from start_port to end_port
252 * \param host host address to bind the socket
253 * \param start_port starting port number to try
254 * \param end_port ending port number to try
255 * \return the port successfully bind to, return -1 if failed to bind any port
256 */
257 inline int TryBindHost(std::string host, int start_port, int end_port) {
258 for (int port = start_port; port < end_port; ++port) {
259 SockAddr addr(host.c_str(), port);
260 if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr),
261 (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) ==
262 0) {
263 return port;
264 } else {
265 LOG(WARNING) << "Bind failed to " << host << ":" << port;
266 }
267#if defined(_WIN32)
268 if (WSAGetLastError() != WSAEADDRINUSE) {
269 Socket::Error("TryBindHost");
270 }
271#else
272 if (errno != EADDRINUSE) {
273 Socket::Error("TryBindHost");
274 }
275#endif
276 }
277 return -1;
278 }
279 /*! \brief get last error code if any */
280 int GetSockError() const {
281 int error = 0;
282 socklen_t len = sizeof(error);
283 if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) {
284 Error("GetSockError");
285 }
286 return error;
287 }
288 /*! \brief check if anything bad happens */
289 bool BadSocket() const {
290 if (IsClosed()) return true;
291 int err = GetSockError();
292 if (err == EBADF || err == EINTR) return true;
293 return false;
294 }
295 /*! \brief check if socket is already closed */
296 bool IsClosed() const { return sockfd == INVALID_SOCKET; }
297 /*! \brief close the socket */
298 void Close() {
299 if (sockfd != INVALID_SOCKET) {
300#ifdef _WIN32
301 closesocket(sockfd);
302#else
303 close(sockfd);
304#endif
305 sockfd = INVALID_SOCKET;
306 } else {
307 Error("Socket::Close double close the socket or close without create");
308 }
309 }
310 /*!
311 * \return last error of socket operation
312 */
313 static int GetLastError() {
314#ifdef _WIN32
315 return WSAGetLastError();
316#else
317 return errno;
318#endif
319 }
320 /*! \return whether last error was would block */
321 static bool LastErrorWouldBlock() {
322 int errsv = GetLastError();
323#ifdef _WIN32
324 return errsv == WSAEWOULDBLOCK;
325#else
326 return errsv == EAGAIN || errsv == EWOULDBLOCK;
327#endif
328 }
329 /*!
330 * \brief start up the socket module
331 * call this before using the sockets
332 */
333 static void Startup() {
334#ifdef _WIN32
335 WSADATA wsa_data;
336 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
337 Socket::Error("Startup");
338 }
339 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
340 WSACleanup();
341 LOG(FATAL) << "Could not find a usable version of Winsock.dll";
342 }
343#endif
344 }
345 /*!
346 * \brief shutdown the socket module after use, all sockets need to be closed
347 */
348 static void Finalize() {
349#ifdef _WIN32
350 WSACleanup();
351#endif
352 }
353 /*!
354 * \brief Report an socket error.
355 * \param msg The error message.
356 */
357 static void Error(const char* msg) {
358 int errsv = GetLastError();
359#ifdef _WIN32
360 LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv;
361#else
362 LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv);
363#endif
364 }
365
366 /*!
367 * \brief Call a function and retry if an EINTR error is encountered.
368 *
369 * Socket operations can return EINTR when the interrupt handler
370 * is registered by the execution environment(e.g. python).
371 * We should retry if there is no KeyboardInterrupt recorded in
372 * the environment.
373 *
374 * \note This function is needed to avoid rare interrupt event
375 * in long running server code.
376 *
377 * \param func The function to retry.
378 * \return The return code returned by function f or error_value on retry failure.
379 */
380 template <typename FuncType>
381 ssize_t RetryCallOnEINTR(FuncType func) {
382 ssize_t ret = func();
383 // common path
384 if (ret != -1) return ret;
385 // less common path
386 do {
387 if (GetLastError() == EINTR) {
388 // Call into env check signals to see if there are
389 // environment specific(e.g. python) signal exceptions.
390 // This function will throw an exception if there is
391 // if the process received a signal that requires TVM to return immediately (e.g. SIGINT).
392 runtime::EnvCheckSignals();
393 } else {
394 // other errors
395 return ret;
396 }
397 ret = func();
398 } while (ret == -1);
399 return ret;
400 }
401
402 protected:
403 explicit Socket(SockType sockfd) : sockfd(sockfd) {}
404};
405
406/*!
407 * \brief a wrapper of TCP socket that hopefully be cross platform
408 */
409class TCPSocket : public Socket {
410 public:
411 TCPSocket() : Socket(INVALID_SOCKET) {}
412 /*!
413 * \brief construct a TCP socket from existing descriptor
414 * \param sockfd The descriptor
415 */
416 explicit TCPSocket(SockType sockfd) : Socket(sockfd) {}
417 /*!
418 * \brief enable/disable TCP keepalive
419 * \param keepalive whether to set the keep alive option on
420 */
421 void SetKeepAlive(bool keepalive) {
422 int opt = static_cast<int>(keepalive);
423 if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char*>(&opt), sizeof(opt)) <
424 0) {
425 Socket::Error("SetKeepAlive");
426 }
427 }
428 /*!
429 * \brief create the socket, call this before using socket
430 * \param af domain
431 */
432 void Create(int af = PF_INET) {
433 sockfd = socket(af, SOCK_STREAM, 0);
434 if (sockfd == INVALID_SOCKET) {
435 Socket::Error("Create");
436 }
437 }
438 /*!
439 * \brief perform listen of the socket
440 * \param backlog backlog parameter
441 */
442 void Listen(int backlog = 16) { listen(sockfd, backlog); }
443 /*!
444 * \brief get a new connection
445 * \return The accepted socket connection.
446 */
447 TCPSocket Accept() {
448 SockType newfd = RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); });
449 if (newfd == INVALID_SOCKET) {
450 Socket::Error("Accept");
451 }
452 return TCPSocket(newfd);
453 }
454 /*!
455 * \brief get a new connection
456 * \param addr client address from which connection accepted
457 * \return The accepted socket connection.
458 */
459 TCPSocket Accept(SockAddr* addr) {
460 socklen_t addrlen = sizeof(addr->addr);
461 SockType newfd = RetryCallOnEINTR(
462 [&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); });
463 if (newfd == INVALID_SOCKET) {
464 Socket::Error("Accept");
465 }
466 return TCPSocket(newfd);
467 }
468 /*!
469 * \brief decide whether the socket is at OOB mark
470 * \return 1 if at mark, 0 if not, -1 if an error occurred
471 */
472 int AtMark() const {
473#ifdef _WIN32
474 unsigned long atmark; // NOLINT(*)
475 if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1;
476#else
477 int atmark;
478 if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1;
479#endif
480 return static_cast<int>(atmark);
481 }
482 /*!
483 * \brief connect to an address
484 * \param addr the address to connect to
485 * \return whether connect is successful
486 */
487 bool Connect(const SockAddr& addr) {
488 return connect(
489 sockfd, reinterpret_cast<const sockaddr*>(&addr.addr),
490 (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0;
491 }
492 /*!
493 * \brief send data using the socket
494 * \param buf_ the pointer to the buffer
495 * \param len the size of the buffer
496 * \param flag extra flags
497 * \return size of data actually sent
498 * return -1 if error occurs
499 */
500 ssize_t Send(const void* buf_, size_t len, int flag = 0) {
501 const char* buf = reinterpret_cast<const char*>(buf_);
502 return RetryCallOnEINTR(
503 [&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); });
504 }
505 /*!
506 * \brief receive data using the socket
507 * \param buf_ the pointer to the buffer
508 * \param len the size of the buffer
509 * \param flags extra flags
510 * \return size of data actually received
511 * return -1 if error occurs
512 */
513 ssize_t Recv(void* buf_, size_t len, int flags = 0) {
514 char* buf = reinterpret_cast<char*>(buf_);
515 return RetryCallOnEINTR(
516 [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); });
517 }
518 /*!
519 * \brief perform block write that will attempt to send all data out
520 * can still return smaller than request when error occurs
521 * \param buf_ the pointer to the buffer
522 * \param len the size of the buffer
523 * \return size of data actually sent
524 */
525 size_t SendAll(const void* buf_, size_t len) {
526 const char* buf = reinterpret_cast<const char*>(buf_);
527 size_t ndone = 0;
528 while (ndone < len) {
529 ssize_t ret = RetryCallOnEINTR(
530 [&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); });
531 if (ret == -1) {
532 if (LastErrorWouldBlock()) return ndone;
533 Socket::Error("SendAll");
534 }
535 buf += ret;
536 ndone += ret;
537 }
538 return ndone;
539 }
540 /*!
541 * \brief perform block read that will attempt to read all data
542 * can still return smaller than request when error occurs
543 * \param buf_ the buffer pointer
544 * \param len length of data to recv
545 * \return size of data actually sent
546 */
547 size_t RecvAll(void* buf_, size_t len) {
548 char* buf = reinterpret_cast<char*>(buf_);
549 size_t ndone = 0;
550 while (ndone < len) {
551 ssize_t ret = RetryCallOnEINTR(
552 [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); });
553 if (ret == -1) {
554 if (LastErrorWouldBlock()) {
555 LOG(FATAL) << "would block";
556 }
557 Socket::Error("RecvAll");
558 }
559 if (ret == 0) return ndone;
560 buf += ret;
561 ndone += ret;
562 }
563 return ndone;
564 }
565 /*!
566 * \brief Send the data to remote.
567 * \param data The data to be sent.
568 */
569 void SendBytes(std::string data) {
570 int datalen = data.length();
571 ICHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen));
572 ICHECK_EQ(SendAll(data.c_str(), datalen), datalen);
573 }
574 /*!
575 * \brief Receive the data to remote.
576 * \return The data received.
577 */
578 std::string RecvBytes() {
579 int datalen = 0;
580 ICHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen));
581 std::string data;
582 data.resize(datalen);
583 ICHECK_EQ(RecvAll(&data[0], datalen), datalen);
584 return data;
585 }
586};
587
588/*! \brief helper data structure to perform poll */
589struct PollHelper {
590 public:
591 /*!
592 * \brief add file descriptor to watch for read
593 * \param fd file descriptor to be watched
594 */
595 inline void WatchRead(TCPSocket::SockType fd) {
596 auto& pfd = fds[fd];
597 pfd.fd = fd;
598 pfd.events |= POLLIN;
599 }
600 /*!
601 * \brief add file descriptor to watch for write
602 * \param fd file descriptor to be watched
603 */
604 inline void WatchWrite(TCPSocket::SockType fd) {
605 auto& pfd = fds[fd];
606 pfd.fd = fd;
607 pfd.events |= POLLOUT;
608 }
609 /*!
610 * \brief add file descriptor to watch for exception
611 * \param fd file descriptor to be watched
612 */
613 inline void WatchException(TCPSocket::SockType fd) {
614 auto& pfd = fds[fd];
615 pfd.fd = fd;
616 pfd.events |= POLLPRI;
617 }
618 /*!
619 * \brief Check if the descriptor is ready for read
620 * \param fd file descriptor to check status
621 */
622 inline bool CheckRead(TCPSocket::SockType fd) const {
623 const auto& pfd = fds.find(fd);
624 return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
625 }
626 /*!
627 * \brief Check if the descriptor is ready for write
628 * \param fd file descriptor to check status
629 */
630 inline bool CheckWrite(TCPSocket::SockType fd) const {
631 const auto& pfd = fds.find(fd);
632 return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
633 }
634 /*!
635 * \brief Check if the descriptor has any exception
636 * \param fd file descriptor to check status
637 */
638 inline bool CheckExcept(TCPSocket::SockType fd) const {
639 const auto& pfd = fds.find(fd);
640 return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0);
641 }
642 /*!
643 * \brief wait for exception event on a single descriptor
644 * \param fd the file descriptor to wait the event for
645 * \param timeout the timeout counter, can be negative, which means wait until the event happen
646 * \return 1 if success, 0 if timeout, and -1 if error occurs
647 */
648 inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*)
649 pollfd pfd;
650 pfd.fd = fd;
651 pfd.events = POLLPRI;
652 return poll(&pfd, 1, timeout);
653 }
654
655 /*!
656 * \brief perform poll on the set defined, read, write, exception
657 * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block
658 * \return
659 */
660 inline void Poll(long timeout = -1) { // NOLINT(*)
661 std::vector<pollfd> fdset;
662 fdset.reserve(fds.size());
663 for (auto kv : fds) {
664 fdset.push_back(kv.second);
665 }
666 int ret = poll(fdset.data(), fdset.size(), timeout);
667 if (ret == -1) {
668 Socket::Error("Poll");
669 } else {
670 for (auto& pfd : fdset) {
671 auto revents = pfd.revents & pfd.events;
672 if (!revents) {
673 fds.erase(pfd.fd);
674 } else {
675 fds[pfd.fd].events = revents;
676 }
677 }
678 }
679 }
680
681 std::unordered_map<TCPSocket::SockType, pollfd> fds;
682};
683
684} // namespace support
685} // namespace tvm
686#endif // TVM_SUPPORT_SOCKET_H_
687