1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file socket.h |
22 | * \brief this file aims to provide a wrapper of sockets |
23 | * \author Tianqi Chen |
24 | */ |
25 | #ifndef TVM_SUPPORT_SOCKET_H_ |
26 | #define TVM_SUPPORT_SOCKET_H_ |
27 | |
28 | #if defined(_WIN32) |
29 | |
30 | #ifndef NOMINMAX |
31 | #define NOMINMAX |
32 | #endif |
33 | |
34 | #include <winsock2.h> |
35 | #include <ws2tcpip.h> |
36 | |
37 | #ifdef _MSC_VER |
38 | #pragma comment(lib, "Ws2_32.lib") |
39 | #endif |
40 | #else |
41 | #include <arpa/inet.h> |
42 | #include <errno.h> |
43 | #include <fcntl.h> |
44 | #include <netdb.h> |
45 | #include <netinet/in.h> |
46 | #include <sys/ioctl.h> |
47 | #include <sys/select.h> |
48 | #include <sys/socket.h> |
49 | #include <unistd.h> |
50 | #endif |
51 | #include <tvm/runtime/logging.h> |
52 | #include <tvm/runtime/registry.h> |
53 | |
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
58 | |
59 | #include "../support/ssize.h" |
60 | #include "../support/utils.h" |
61 | |
62 | #if defined(_WIN32) |
63 | static inline int poll(struct pollfd* pfd, int nfds, int timeout) { |
64 | return WSAPoll(pfd, nfds, timeout); |
65 | } |
66 | #else |
67 | #include <sys/poll.h> |
68 | #endif // defined(_WIN32) |
69 | |
70 | namespace tvm { |
71 | namespace support { |
72 | |
73 | /*! |
74 | * \brief Get current host name. |
75 | * \return The hostname. |
76 | */ |
77 | inline std::string GetHostName() { |
78 | std::string buf; |
79 | buf.resize(256); |
80 | ICHECK_NE(gethostname(&buf[0], 256), -1); |
81 | return std::string(buf.c_str()); |
82 | } |
83 | |
/*!
 * \brief Check whether a string is a usable IP address.
 * \param ip The literal "localhost", a dotted-quad IPv4 address, or a textual IPv6 address.
 * \return true if \p ip is "localhost" or parses as an IPv4 or IPv6 address.
 */
inline bool ValidateIP(const std::string& ip) {
  if (ip == "localhost") {
    return true;
  }
  // inet_pton returns 1 on success, 0 for an unparsable string and -1 for an
  // unsupported family. Only 1 means "valid"; converting the raw result to
  // bool would wrongly treat -1 as success.
  in_addr ipv4_buf;
  in6_addr ipv6_buf;
  return inet_pton(AF_INET, ip.c_str(), &ipv4_buf) == 1 ||
         inet_pton(AF_INET6, ip.c_str(), &ipv6_buf) == 1;
}
99 | |
100 | /*! |
101 | * \brief Common data structure for network address. |
102 | */ |
103 | struct SockAddr { |
104 | sockaddr_storage addr; |
105 | SockAddr() {} |
106 | /*! |
107 | * \brief construct address by url and port |
108 | * \param url The url of the address |
109 | * \param port The port of the address. |
110 | */ |
111 | SockAddr(const char* url, int port) { this->Set(url, port); } |
112 | |
113 | /*! |
114 | * \brief SockAddr Get the socket address from tracker. |
115 | * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) |
116 | * \return SockAddr parsed from url. |
117 | */ |
118 | explicit SockAddr(const std::string& url) { |
119 | size_t sep = url.find("," ); |
120 | std::string host = url.substr(2, sep - 3); |
121 | std::string port = url.substr(sep + 1, url.length() - 1); |
122 | ICHECK(ValidateIP(host)) << "Url address is not valid " << url; |
123 | if (host == "localhost" ) { |
124 | host = "127.0.0.1" ; |
125 | } |
126 | this->Set(host.c_str(), std::stoi(port)); |
127 | } |
128 | |
129 | /*! |
130 | * \brief set the address |
131 | * \param host the url of the address |
132 | * \param port the port of address |
133 | */ |
134 | void Set(const char* host, int port) { |
135 | addrinfo hints; |
136 | memset(&hints, 0, sizeof(hints)); |
137 | hints.ai_family = PF_UNSPEC; |
138 | hints.ai_flags = AI_PASSIVE; |
139 | hints.ai_socktype = SOCK_STREAM; |
140 | addrinfo* res = nullptr; |
141 | int sig = getaddrinfo(host, nullptr, &hints, &res); |
142 | ICHECK(sig == 0 && res != nullptr) << "cannot obtain address of " << host; |
143 | switch (res->ai_family) { |
144 | case AF_INET: { |
145 | sockaddr_in* addr4 = reinterpret_cast<sockaddr_in*>(&addr); |
146 | memcpy(addr4, res->ai_addr, res->ai_addrlen); |
147 | addr4->sin_port = htons(port); |
148 | addr4->sin_family = AF_INET; |
149 | } break; |
150 | case AF_INET6: { |
151 | sockaddr_in6* addr6 = reinterpret_cast<sockaddr_in6*>(&addr); |
152 | memcpy(addr6, res->ai_addr, res->ai_addrlen); |
153 | addr6->sin6_port = htons(port); |
154 | addr6->sin6_family = AF_INET6; |
155 | } break; |
156 | default: |
157 | ICHECK(false) << "cannot decode address" ; |
158 | } |
159 | freeaddrinfo(res); |
160 | } |
161 | /*! \brief return port of the address */ |
162 | int port() const { |
163 | return ntohs((addr.ss_family == AF_INET6) |
164 | ? reinterpret_cast<const sockaddr_in6*>(&addr)->sin6_port |
165 | : reinterpret_cast<const sockaddr_in*>(&addr)->sin_port); |
166 | } |
167 | /*! \brief return the ip address family */ |
168 | int ss_family() const { return addr.ss_family; } |
169 | /*! \return a string representation of the address */ |
170 | std::string AsString() const { |
171 | std::string buf; |
172 | buf.resize(256); |
173 | |
174 | const void* sinx_addr = nullptr; |
175 | if (addr.ss_family == AF_INET6) { |
176 | const in6_addr& addr6 = reinterpret_cast<const sockaddr_in6*>(&addr)->sin6_addr; |
177 | sinx_addr = reinterpret_cast<const void*>(&addr6); |
178 | } else if (addr.ss_family == AF_INET) { |
179 | const in_addr& addr4 = reinterpret_cast<const sockaddr_in*>(&addr)->sin_addr; |
180 | sinx_addr = reinterpret_cast<const void*>(&addr4); |
181 | } else { |
182 | ICHECK(false) << "illegal address" ; |
183 | } |
184 | |
185 | #ifdef _WIN32 |
186 | const char* s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) |
187 | &buf[0], buf.length()); |
188 | #else |
189 | const char* s = |
190 | inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast<socklen_t>(buf.length())); |
191 | #endif |
192 | ICHECK(s != nullptr) << "cannot decode address" ; |
193 | std::ostringstream os; |
194 | os << s << ":" << port(); |
195 | return os.str(); |
196 | } |
197 | }; |
198 | /*! |
199 | * \brief base class containing common operations of TCP and UDP sockets |
200 | */ |
201 | class Socket { |
202 | public: |
203 | #if defined(_WIN32) |
204 | using sock_size_t = int; |
205 | using SockType = SOCKET; |
206 | #else |
207 | using SockType = int; |
208 | using sock_size_t = size_t; |
209 | static constexpr int INVALID_SOCKET = -1; |
210 | #endif |
211 | /*! \brief the file descriptor of socket */ |
212 | SockType sockfd; |
213 | /*! |
214 | * \brief set this socket to use non-blocking mode |
215 | * \param non_block whether set it to be non-block, if it is false |
216 | * it will set it back to block mode |
217 | */ |
218 | void SetNonBlock(bool non_block) { |
219 | #ifdef _WIN32 |
220 | u_long mode = non_block ? 1 : 0; |
221 | if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { |
222 | Socket::Error("SetNonBlock" ); |
223 | } |
224 | #else |
225 | int flag = fcntl(sockfd, F_GETFL, 0); |
226 | if (flag == -1) { |
227 | Socket::Error("SetNonBlock-1" ); |
228 | } |
229 | if (non_block) { |
230 | flag |= O_NONBLOCK; |
231 | } else { |
232 | flag &= ~O_NONBLOCK; |
233 | } |
234 | if (fcntl(sockfd, F_SETFL, flag) == -1) { |
235 | Socket::Error("SetNonBlock-2" ); |
236 | } |
237 | #endif |
238 | } |
239 | /*! |
240 | * \brief bind the socket to an address |
241 | * \param addr The address to be binded |
242 | */ |
243 | void Bind(const SockAddr& addr) { |
244 | if (bind(sockfd, reinterpret_cast<const sockaddr*>(&addr.addr), |
245 | (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == |
246 | -1) { |
247 | Socket::Error("Bind" ); |
248 | } |
249 | } |
250 | /*! |
251 | * \brief try bind the socket to host, from start_port to end_port |
252 | * \param host host address to bind the socket |
253 | * \param start_port starting port number to try |
254 | * \param end_port ending port number to try |
255 | * \return the port successfully bind to, return -1 if failed to bind any port |
256 | */ |
257 | inline int TryBindHost(std::string host, int start_port, int end_port) { |
258 | for (int port = start_port; port < end_port; ++port) { |
259 | SockAddr addr(host.c_str(), port); |
260 | if (bind(sockfd, reinterpret_cast<sockaddr*>(&addr.addr), |
261 | (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == |
262 | 0) { |
263 | return port; |
264 | } else { |
265 | LOG(WARNING) << "Bind failed to " << host << ":" << port; |
266 | } |
267 | #if defined(_WIN32) |
268 | if (WSAGetLastError() != WSAEADDRINUSE) { |
269 | Socket::Error("TryBindHost" ); |
270 | } |
271 | #else |
272 | if (errno != EADDRINUSE) { |
273 | Socket::Error("TryBindHost" ); |
274 | } |
275 | #endif |
276 | } |
277 | return -1; |
278 | } |
279 | /*! \brief get last error code if any */ |
280 | int GetSockError() const { |
281 | int error = 0; |
282 | socklen_t len = sizeof(error); |
283 | if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0) { |
284 | Error("GetSockError" ); |
285 | } |
286 | return error; |
287 | } |
288 | /*! \brief check if anything bad happens */ |
289 | bool BadSocket() const { |
290 | if (IsClosed()) return true; |
291 | int err = GetSockError(); |
292 | if (err == EBADF || err == EINTR) return true; |
293 | return false; |
294 | } |
295 | /*! \brief check if socket is already closed */ |
296 | bool IsClosed() const { return sockfd == INVALID_SOCKET; } |
297 | /*! \brief close the socket */ |
298 | void Close() { |
299 | if (sockfd != INVALID_SOCKET) { |
300 | #ifdef _WIN32 |
301 | closesocket(sockfd); |
302 | #else |
303 | close(sockfd); |
304 | #endif |
305 | sockfd = INVALID_SOCKET; |
306 | } else { |
307 | Error("Socket::Close double close the socket or close without create" ); |
308 | } |
309 | } |
310 | /*! |
311 | * \return last error of socket operation |
312 | */ |
313 | static int GetLastError() { |
314 | #ifdef _WIN32 |
315 | return WSAGetLastError(); |
316 | #else |
317 | return errno; |
318 | #endif |
319 | } |
320 | /*! \return whether last error was would block */ |
321 | static bool LastErrorWouldBlock() { |
322 | int errsv = GetLastError(); |
323 | #ifdef _WIN32 |
324 | return errsv == WSAEWOULDBLOCK; |
325 | #else |
326 | return errsv == EAGAIN || errsv == EWOULDBLOCK; |
327 | #endif |
328 | } |
329 | /*! |
330 | * \brief start up the socket module |
331 | * call this before using the sockets |
332 | */ |
333 | static void Startup() { |
334 | #ifdef _WIN32 |
335 | WSADATA wsa_data; |
336 | if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { |
337 | Socket::Error("Startup" ); |
338 | } |
339 | if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { |
340 | WSACleanup(); |
341 | LOG(FATAL) << "Could not find a usable version of Winsock.dll" ; |
342 | } |
343 | #endif |
344 | } |
345 | /*! |
346 | * \brief shutdown the socket module after use, all sockets need to be closed |
347 | */ |
348 | static void Finalize() { |
349 | #ifdef _WIN32 |
350 | WSACleanup(); |
351 | #endif |
352 | } |
353 | /*! |
354 | * \brief Report an socket error. |
355 | * \param msg The error message. |
356 | */ |
357 | static void Error(const char* msg) { |
358 | int errsv = GetLastError(); |
359 | #ifdef _WIN32 |
360 | LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv; |
361 | #else |
362 | LOG(FATAL) << "Socket " << msg << " Error:" << strerror(errsv); |
363 | #endif |
364 | } |
365 | |
366 | /*! |
367 | * \brief Call a function and retry if an EINTR error is encountered. |
368 | * |
369 | * Socket operations can return EINTR when the interrupt handler |
370 | * is registered by the execution environment(e.g. python). |
371 | * We should retry if there is no KeyboardInterrupt recorded in |
372 | * the environment. |
373 | * |
374 | * \note This function is needed to avoid rare interrupt event |
375 | * in long running server code. |
376 | * |
377 | * \param func The function to retry. |
378 | * \return The return code returned by function f or error_value on retry failure. |
379 | */ |
380 | template <typename FuncType> |
381 | ssize_t RetryCallOnEINTR(FuncType func) { |
382 | ssize_t ret = func(); |
383 | // common path |
384 | if (ret != -1) return ret; |
385 | // less common path |
386 | do { |
387 | if (GetLastError() == EINTR) { |
388 | // Call into env check signals to see if there are |
389 | // environment specific(e.g. python) signal exceptions. |
390 | // This function will throw an exception if there is |
391 | // if the process received a signal that requires TVM to return immediately (e.g. SIGINT). |
392 | runtime::EnvCheckSignals(); |
393 | } else { |
394 | // other errors |
395 | return ret; |
396 | } |
397 | ret = func(); |
398 | } while (ret == -1); |
399 | return ret; |
400 | } |
401 | |
402 | protected: |
403 | explicit Socket(SockType sockfd) : sockfd(sockfd) {} |
404 | }; |
405 | |
406 | /*! |
407 | * \brief a wrapper of TCP socket that hopefully be cross platform |
408 | */ |
409 | class TCPSocket : public Socket { |
410 | public: |
411 | TCPSocket() : Socket(INVALID_SOCKET) {} |
412 | /*! |
413 | * \brief construct a TCP socket from existing descriptor |
414 | * \param sockfd The descriptor |
415 | */ |
416 | explicit TCPSocket(SockType sockfd) : Socket(sockfd) {} |
417 | /*! |
418 | * \brief enable/disable TCP keepalive |
419 | * \param keepalive whether to set the keep alive option on |
420 | */ |
421 | void SetKeepAlive(bool keepalive) { |
422 | int opt = static_cast<int>(keepalive); |
423 | if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char*>(&opt), sizeof(opt)) < |
424 | 0) { |
425 | Socket::Error("SetKeepAlive" ); |
426 | } |
427 | } |
428 | /*! |
429 | * \brief create the socket, call this before using socket |
430 | * \param af domain |
431 | */ |
432 | void Create(int af = PF_INET) { |
433 | sockfd = socket(af, SOCK_STREAM, 0); |
434 | if (sockfd == INVALID_SOCKET) { |
435 | Socket::Error("Create" ); |
436 | } |
437 | } |
438 | /*! |
439 | * \brief perform listen of the socket |
440 | * \param backlog backlog parameter |
441 | */ |
442 | void Listen(int backlog = 16) { listen(sockfd, backlog); } |
443 | /*! |
444 | * \brief get a new connection |
445 | * \return The accepted socket connection. |
446 | */ |
447 | TCPSocket Accept() { |
448 | SockType newfd = RetryCallOnEINTR([&]() { return accept(sockfd, nullptr, nullptr); }); |
449 | if (newfd == INVALID_SOCKET) { |
450 | Socket::Error("Accept" ); |
451 | } |
452 | return TCPSocket(newfd); |
453 | } |
454 | /*! |
455 | * \brief get a new connection |
456 | * \param addr client address from which connection accepted |
457 | * \return The accepted socket connection. |
458 | */ |
459 | TCPSocket Accept(SockAddr* addr) { |
460 | socklen_t addrlen = sizeof(addr->addr); |
461 | SockType newfd = RetryCallOnEINTR( |
462 | [&]() { return accept(sockfd, reinterpret_cast<sockaddr*>(&addr->addr), &addrlen); }); |
463 | if (newfd == INVALID_SOCKET) { |
464 | Socket::Error("Accept" ); |
465 | } |
466 | return TCPSocket(newfd); |
467 | } |
468 | /*! |
469 | * \brief decide whether the socket is at OOB mark |
470 | * \return 1 if at mark, 0 if not, -1 if an error occurred |
471 | */ |
472 | int AtMark() const { |
473 | #ifdef _WIN32 |
474 | unsigned long atmark; // NOLINT(*) |
475 | if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; |
476 | #else |
477 | int atmark; |
478 | if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; |
479 | #endif |
480 | return static_cast<int>(atmark); |
481 | } |
482 | /*! |
483 | * \brief connect to an address |
484 | * \param addr the address to connect to |
485 | * \return whether connect is successful |
486 | */ |
487 | bool Connect(const SockAddr& addr) { |
488 | return connect( |
489 | sockfd, reinterpret_cast<const sockaddr*>(&addr.addr), |
490 | (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0; |
491 | } |
492 | /*! |
493 | * \brief send data using the socket |
494 | * \param buf_ the pointer to the buffer |
495 | * \param len the size of the buffer |
496 | * \param flag extra flags |
497 | * \return size of data actually sent |
498 | * return -1 if error occurs |
499 | */ |
500 | ssize_t Send(const void* buf_, size_t len, int flag = 0) { |
501 | const char* buf = reinterpret_cast<const char*>(buf_); |
502 | return RetryCallOnEINTR( |
503 | [&]() { return send(sockfd, buf, static_cast<sock_size_t>(len), flag); }); |
504 | } |
505 | /*! |
506 | * \brief receive data using the socket |
507 | * \param buf_ the pointer to the buffer |
508 | * \param len the size of the buffer |
509 | * \param flags extra flags |
510 | * \return size of data actually received |
511 | * return -1 if error occurs |
512 | */ |
513 | ssize_t Recv(void* buf_, size_t len, int flags = 0) { |
514 | char* buf = reinterpret_cast<char*>(buf_); |
515 | return RetryCallOnEINTR( |
516 | [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len), flags); }); |
517 | } |
518 | /*! |
519 | * \brief perform block write that will attempt to send all data out |
520 | * can still return smaller than request when error occurs |
521 | * \param buf_ the pointer to the buffer |
522 | * \param len the size of the buffer |
523 | * \return size of data actually sent |
524 | */ |
525 | size_t SendAll(const void* buf_, size_t len) { |
526 | const char* buf = reinterpret_cast<const char*>(buf_); |
527 | size_t ndone = 0; |
528 | while (ndone < len) { |
529 | ssize_t ret = RetryCallOnEINTR( |
530 | [&]() { return send(sockfd, buf, static_cast<ssize_t>(len - ndone), 0); }); |
531 | if (ret == -1) { |
532 | if (LastErrorWouldBlock()) return ndone; |
533 | Socket::Error("SendAll" ); |
534 | } |
535 | buf += ret; |
536 | ndone += ret; |
537 | } |
538 | return ndone; |
539 | } |
540 | /*! |
541 | * \brief perform block read that will attempt to read all data |
542 | * can still return smaller than request when error occurs |
543 | * \param buf_ the buffer pointer |
544 | * \param len length of data to recv |
545 | * \return size of data actually sent |
546 | */ |
547 | size_t RecvAll(void* buf_, size_t len) { |
548 | char* buf = reinterpret_cast<char*>(buf_); |
549 | size_t ndone = 0; |
550 | while (ndone < len) { |
551 | ssize_t ret = RetryCallOnEINTR( |
552 | [&]() { return recv(sockfd, buf, static_cast<sock_size_t>(len - ndone), MSG_WAITALL); }); |
553 | if (ret == -1) { |
554 | if (LastErrorWouldBlock()) { |
555 | LOG(FATAL) << "would block" ; |
556 | } |
557 | Socket::Error("RecvAll" ); |
558 | } |
559 | if (ret == 0) return ndone; |
560 | buf += ret; |
561 | ndone += ret; |
562 | } |
563 | return ndone; |
564 | } |
565 | /*! |
566 | * \brief Send the data to remote. |
567 | * \param data The data to be sent. |
568 | */ |
569 | void SendBytes(std::string data) { |
570 | int datalen = data.length(); |
571 | ICHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen)); |
572 | ICHECK_EQ(SendAll(data.c_str(), datalen), datalen); |
573 | } |
574 | /*! |
575 | * \brief Receive the data to remote. |
576 | * \return The data received. |
577 | */ |
578 | std::string RecvBytes() { |
579 | int datalen = 0; |
580 | ICHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen)); |
581 | std::string data; |
582 | data.resize(datalen); |
583 | ICHECK_EQ(RecvAll(&data[0], datalen), datalen); |
584 | return data; |
585 | } |
586 | }; |
587 | |
588 | /*! \brief helper data structure to perform poll */ |
589 | struct PollHelper { |
590 | public: |
591 | /*! |
592 | * \brief add file descriptor to watch for read |
593 | * \param fd file descriptor to be watched |
594 | */ |
595 | inline void WatchRead(TCPSocket::SockType fd) { |
596 | auto& pfd = fds[fd]; |
597 | pfd.fd = fd; |
598 | pfd.events |= POLLIN; |
599 | } |
600 | /*! |
601 | * \brief add file descriptor to watch for write |
602 | * \param fd file descriptor to be watched |
603 | */ |
604 | inline void WatchWrite(TCPSocket::SockType fd) { |
605 | auto& pfd = fds[fd]; |
606 | pfd.fd = fd; |
607 | pfd.events |= POLLOUT; |
608 | } |
609 | /*! |
610 | * \brief add file descriptor to watch for exception |
611 | * \param fd file descriptor to be watched |
612 | */ |
613 | inline void WatchException(TCPSocket::SockType fd) { |
614 | auto& pfd = fds[fd]; |
615 | pfd.fd = fd; |
616 | pfd.events |= POLLPRI; |
617 | } |
618 | /*! |
619 | * \brief Check if the descriptor is ready for read |
620 | * \param fd file descriptor to check status |
621 | */ |
622 | inline bool CheckRead(TCPSocket::SockType fd) const { |
623 | const auto& pfd = fds.find(fd); |
624 | return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); |
625 | } |
626 | /*! |
627 | * \brief Check if the descriptor is ready for write |
628 | * \param fd file descriptor to check status |
629 | */ |
630 | inline bool CheckWrite(TCPSocket::SockType fd) const { |
631 | const auto& pfd = fds.find(fd); |
632 | return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); |
633 | } |
634 | /*! |
635 | * \brief Check if the descriptor has any exception |
636 | * \param fd file descriptor to check status |
637 | */ |
638 | inline bool CheckExcept(TCPSocket::SockType fd) const { |
639 | const auto& pfd = fds.find(fd); |
640 | return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); |
641 | } |
642 | /*! |
643 | * \brief wait for exception event on a single descriptor |
644 | * \param fd the file descriptor to wait the event for |
645 | * \param timeout the timeout counter, can be negative, which means wait until the event happen |
646 | * \return 1 if success, 0 if timeout, and -1 if error occurs |
647 | */ |
648 | inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) |
649 | pollfd pfd; |
650 | pfd.fd = fd; |
651 | pfd.events = POLLPRI; |
652 | return poll(&pfd, 1, timeout); |
653 | } |
654 | |
655 | /*! |
656 | * \brief perform poll on the set defined, read, write, exception |
657 | * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block |
658 | * \return |
659 | */ |
660 | inline void Poll(long timeout = -1) { // NOLINT(*) |
661 | std::vector<pollfd> fdset; |
662 | fdset.reserve(fds.size()); |
663 | for (auto kv : fds) { |
664 | fdset.push_back(kv.second); |
665 | } |
666 | int ret = poll(fdset.data(), fdset.size(), timeout); |
667 | if (ret == -1) { |
668 | Socket::Error("Poll" ); |
669 | } else { |
670 | for (auto& pfd : fdset) { |
671 | auto revents = pfd.revents & pfd.events; |
672 | if (!revents) { |
673 | fds.erase(pfd.fd); |
674 | } else { |
675 | fds[pfd.fd].events = revents; |
676 | } |
677 | } |
678 | } |
679 | } |
680 | |
681 | std::unordered_map<TCPSocket::SockType, pollfd> fds; |
682 | }; |
683 | |
684 | } // namespace support |
685 | } // namespace tvm |
686 | #endif // TVM_SUPPORT_SOCKET_H_ |
687 | |