#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>

#include <fcntl.h>
#include <algorithm>
#include <array>
#include <cerrno>
#include <system_error>
#include <thread>
#include <unordered_map>
#include <utility>

#ifdef _WIN32
#include <io.h>
#include <winsock2.h>
#else
#include <poll.h>
#include <unistd.h>
#endif

#ifdef _WIN32
#include <torch/csrc/distributed/c10d/WinSockUtils.hpp>
#else
#include <torch/csrc/distributed/c10d/UnixSockUtils.hpp>
#endif

#include <torch/csrc/distributed/c10d/socket.h>
27 | |
28 | namespace c10d { |
29 | namespace detail { |
30 | namespace { |
31 | |
32 | // Abstract base class to handle thread state for TCPStoreMasterDaemon and |
33 | // TCPStoreWorkerDaemon. Contains the windows/unix implementations to signal a |
34 | // shutdown sequence for the thread |
class BackgroundThread {
 public:
  explicit BackgroundThread(Socket&& storeListenSocket);

  // Pure virtual to force subclassing; a definition is still provided so the
  // base destructor can run during subclass destruction.
  virtual ~BackgroundThread() = 0;

 protected:
  // Stops, joins and cleans up the daemon thread. Must be called from the
  // subclass destructor; see the WARNING at the dispose() definition.
  void dispose();

  // Listening socket that new client connections are accepted from.
  Socket storeListenSocket_;
  // The daemon thread itself; launched by the subclass constructor.
  std::thread daemonThread_{};
  // Sockets accepted from storeListenSocket_, owned by this object.
  std::vector<Socket> sockets_{};
#ifdef _WIN32
  // Poll timeout so the Windows loop can periodically check the stop event.
  const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10};
  // Event object signalled by stop() to request thread shutdown.
  HANDLE ghStopEvent_{};
#else
  // Self-pipe used to wake the poll() loop on shutdown:
  // [0] = read end (polled by the daemon), [1] = write end (written by stop()).
  std::array<int, 2> controlPipeFd_{{-1, -1}};
#endif

 private:
  // Initialization for shutdown signal
  void initStopSignal();
  // Triggers the shutdown signal
  void stop();
  // Joins the thread
  void join();
  // Clean up the shutdown signal
  void closeStopSignal();
};
64 | |
// Background thread parent class methods

// Takes ownership of the listening socket and prepares the shutdown signal
// (event object on Windows, self-pipe elsewhere). The daemon thread itself
// is started by the subclass constructor.
BackgroundThread::BackgroundThread(Socket&& storeListenSocket)
    : storeListenSocket_{std::move(storeListenSocket)} {
  // Signal instance destruction to the daemon thread.
  initStopSignal();
}

// Out-of-line definition required even though the destructor is pure virtual.
BackgroundThread::~BackgroundThread() = default;
73 | |
// WARNING:
// Since we rely on the subclass for the daemon thread clean-up, we cannot
// destruct our member variables in the destructor. The subclass must call
// dispose() in its own destructor.
void BackgroundThread::dispose() {
  // Stop the run
  stop();
  // Join the thread
  join();
  // Close unclosed sockets
  sockets_.clear();
  // Now close the rest control pipe
  closeStopSignal();
}

// Blocks until the daemon thread has exited.
void BackgroundThread::join() {
  daemonThread_.join();
}
92 | |
93 | #ifdef _WIN32 |
94 | void BackgroundThread::initStopSignal() { |
95 | ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL); |
96 | if (ghStopEvent_ == NULL) { |
97 | TORCH_CHECK( |
98 | false, |
99 | "Failed to create the control pipe to start the " |
100 | "BackgroundThread run" ); |
101 | } |
102 | } |
103 | |
104 | void BackgroundThread::closeStopSignal() { |
105 | CloseHandle(ghStopEvent_); |
106 | } |
107 | |
108 | void BackgroundThread::stop() { |
109 | SetEvent(ghStopEvent_); |
110 | } |
111 | #else |
112 | void BackgroundThread::initStopSignal() { |
113 | if (pipe(controlPipeFd_.data()) == -1) { |
114 | TORCH_CHECK( |
115 | false, |
116 | "Failed to create the control pipe to start the " |
117 | "BackgroundThread run" ); |
118 | } |
119 | } |
120 | |
121 | void BackgroundThread::closeStopSignal() { |
122 | for (int fd : controlPipeFd_) { |
123 | if (fd != -1) { |
124 | ::close(fd); |
125 | } |
126 | } |
127 | } |
128 | |
129 | void BackgroundThread::stop() { |
130 | if (controlPipeFd_[1] != -1) { |
131 | ::write(controlPipeFd_[1], "\0" , 1); |
132 | // close the write end of the pipe |
133 | ::close(controlPipeFd_[1]); |
134 | controlPipeFd_[1] = -1; |
135 | } |
136 | } |
137 | #endif |
138 | |
// Wire-protocol op code sent by a client at the start of each request. The
// numeric values are implied by declaration order and travel over the wire,
// so the enumerator order must not change.
enum class QueryType : uint8_t {
  SET,
  COMPARE_SET,
  GET,
  ADD,
  CHECK,
  WAIT,
  GETNUMKEYS,
  WATCH_KEY,
  DELETE_KEY,
};

// Reply to a CHECK query.
enum class CheckResponseType : uint8_t { READY, NOT_READY };

// Reply that unblocks a client waiting on a WAIT query.
enum class WaitResponseType : uint8_t { STOP_WAITING };

// Push notifications delivered to TCPStoreWorkerDaemon for watched keys.
enum class WatchResponseType : uint8_t {
  KEY_UPDATED,
  KEY_CREATED,
  KEY_DELETED,
  KEY_CALLBACK_REGISTERED
};
161 | |
// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
 public:
  explicit TCPStoreMasterDaemon(Socket&& storeListenSocket);

  ~TCPStoreMasterDaemon() override;

 private:
  // Event loop body: accepts connections and dispatches client queries.
  void run();
  // Processes every connection fd that poll() reported an event on.
  void queryFds(std::vector<struct pollfd>& fds);
  // Reads one request from `socket` and dispatches to the handler below.
  void query(int socket);

  // The master runs on a single thread so only
  // one handler can be executed at a time
  void setHandler(int socket);
  void compareSetHandler(int socket);
  void addHandler(int socket);
  void getHandler(int socket) const;
  void checkHandler(int socket) const;
  void getNumKeysHandler(int socket) const;
  void deleteHandler(int socket);
  void waitHandler(int socket);
  void watchHandler(int socket);

  // True iff every key in `keys` is present in the store.
  bool checkKeys(const std::vector<std::string>& keys) const;
  // Helper function to alerts waiting workers, used in setHandler, getHandler
  void wakeupWaitingClients(const std::string& key);
  // Helper function used when the key is changed
  // used in setHandler, addHandler, getHandler, deleteHandler
  void sendKeyUpdatesToClients(
      const std::string& key,
      const enum WatchResponseType& type,
      std::vector<uint8_t>& oldData,
      std::vector<uint8_t>& newData);
  // The key/value data itself.
  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // From key -> the list of sockets waiting on the key
  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
  // From socket -> number of keys awaited
  std::unordered_map<int, size_t> keysAwaited_;
  // From key -> the list of sockets watching the key
  std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};
204 | |
// Simply start the daemon thread
TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket&& storeListenSocket)
    : BackgroundThread{std::move(storeListenSocket)} {
  daemonThread_ = std::thread{&TCPStoreMasterDaemon::run, this};
}

// dispose() must be called here (not in the base destructor) so the daemon
// thread is stopped and joined while this subclass is still alive.
TCPStoreMasterDaemon::~TCPStoreMasterDaemon() {
  dispose();
}
214 | |
215 | void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) { |
216 | // Skipping the fds[0] and fds[1], |
217 | // fds[0] is master's listening socket |
218 | // fds[1] is control pipe's reading fd, it is not for Windows platform |
219 | for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) { |
220 | if (fds[fdIdx].revents == 0) { |
221 | continue; |
222 | } |
223 | |
224 | // Now query the socket that has the event |
225 | try { |
226 | query(fds[fdIdx].fd); |
227 | } catch (...) { |
228 | // There was an error when processing query. Probably an exception |
229 | // occurred in recv/send what would indicate that socket on the other |
230 | // side has been closed. If the closing was due to normal exit, then |
231 | // the store should continue executing. Otherwise, if it was different |
232 | // exception, other connections will get an exception once they try to |
233 | // use the store. We will go ahead and close this connection whenever |
234 | // we hit an exception here. |
235 | |
236 | // Remove all the tracking state of the close FD |
237 | for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) { |
238 | for (auto vecIt = it->second.begin(); vecIt != it->second.end();) { |
239 | if (*vecIt == fds[fdIdx].fd) { |
240 | vecIt = it->second.erase(vecIt); |
241 | } else { |
242 | ++vecIt; |
243 | } |
244 | } |
245 | if (it->second.empty()) { |
246 | it = waitingSockets_.erase(it); |
247 | } else { |
248 | ++it; |
249 | } |
250 | } |
251 | for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) { |
252 | if (it->first == fds[fdIdx].fd) { |
253 | it = keysAwaited_.erase(it); |
254 | } else { |
255 | ++it; |
256 | } |
257 | } |
258 | fds.erase(fds.begin() + fdIdx); |
259 | sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET); |
260 | --fdIdx; |
261 | continue; |
262 | } |
263 | } |
264 | } |
265 | |
266 | // query communicates with the worker. The format |
267 | // of the query is as follows: |
268 | // type of query | size of arg1 | arg1 | size of arg2 | arg2 | ... |
269 | // or, in the case of wait |
270 | // type of query | number of args | size of arg1 | arg1 | ... |
271 | void TCPStoreMasterDaemon::query(int socket) { |
272 | QueryType qt; |
273 | tcputil::recvBytes<QueryType>(socket, &qt, 1); |
274 | if (qt == QueryType::SET) { |
275 | setHandler(socket); |
276 | |
277 | } else if (qt == QueryType::COMPARE_SET) { |
278 | compareSetHandler(socket); |
279 | |
280 | } else if (qt == QueryType::ADD) { |
281 | addHandler(socket); |
282 | |
283 | } else if (qt == QueryType::GET) { |
284 | getHandler(socket); |
285 | |
286 | } else if (qt == QueryType::CHECK) { |
287 | checkHandler(socket); |
288 | |
289 | } else if (qt == QueryType::WAIT) { |
290 | waitHandler(socket); |
291 | |
292 | } else if (qt == QueryType::GETNUMKEYS) { |
293 | getNumKeysHandler(socket); |
294 | |
295 | } else if (qt == QueryType::DELETE_KEY) { |
296 | deleteHandler(socket); |
297 | |
298 | } else if (qt == QueryType::WATCH_KEY) { |
299 | watchHandler(socket); |
300 | |
301 | } else { |
302 | TORCH_CHECK(false, "Unexpected query type" ); |
303 | } |
304 | } |
305 | |
306 | void TCPStoreMasterDaemon::wakeupWaitingClients(const std::string& key) { |
307 | auto socketsToWait = waitingSockets_.find(key); |
308 | if (socketsToWait != waitingSockets_.end()) { |
309 | for (int socket : socketsToWait->second) { |
310 | if (--keysAwaited_[socket] == 0) { |
311 | tcputil::sendValue<WaitResponseType>( |
312 | socket, WaitResponseType::STOP_WAITING); |
313 | } |
314 | } |
315 | waitingSockets_.erase(socketsToWait); |
316 | } |
317 | } |
318 | |
319 | void TCPStoreMasterDaemon::sendKeyUpdatesToClients( |
320 | const std::string& key, |
321 | const enum WatchResponseType& type, |
322 | std::vector<uint8_t>& oldData, |
323 | std::vector<uint8_t>& newData) { |
324 | for (int socket : watchedSockets_[key]) { |
325 | tcputil::sendValue<WatchResponseType>(socket, type); |
326 | tcputil::sendString(socket, key, true); |
327 | tcputil::sendVector<uint8_t>(socket, oldData); |
328 | tcputil::sendVector<uint8_t>(socket, newData); |
329 | } |
330 | } |
331 | |
332 | void TCPStoreMasterDaemon::setHandler(int socket) { |
333 | std::string key = tcputil::recvString(socket); |
334 | std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket); |
335 | std::vector<uint8_t> oldData; |
336 | bool newKey = true; |
337 | auto it = tcpStore_.find(key); |
338 | if (it != tcpStore_.end()) { |
339 | oldData = it->second; |
340 | newKey = false; |
341 | } |
342 | tcpStore_[key] = newData; |
343 | // On "set", wake up all clients that have been waiting |
344 | wakeupWaitingClients(key); |
345 | // Send key update to all watching clients |
346 | newKey ? sendKeyUpdatesToClients( |
347 | key, WatchResponseType::KEY_CREATED, oldData, newData) |
348 | : sendKeyUpdatesToClients( |
349 | key, WatchResponseType::KEY_UPDATED, oldData, newData); |
350 | } |
351 | |
// Handles COMPARE_SET: replaces the stored value with `newValue` iff the
// stored value equals `currentValue`; always replies with whatever value is
// stored after the operation (or echoes `currentValue` when the key is
// absent and the expected value was non-empty).
void TCPStoreMasterDaemon::compareSetHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> currentValue = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> newValue = tcputil::recvVector<uint8_t>(socket);

  auto pos = tcpStore_.find(key);
  if (pos == tcpStore_.end()) {
    if (currentValue.empty()) {
      // Empty expected value + missing key means "create".
      tcpStore_[key] = newValue;

      // Send key update to all watching clients
      sendKeyUpdatesToClients(
          key, WatchResponseType::KEY_CREATED, currentValue, newValue);
      tcputil::sendVector<uint8_t>(socket, newValue);
    } else {
      // TODO: This code path is not ideal as we are "lying" to the caller in
      // case the key does not exist. We should come up with a working solution.
      tcputil::sendVector<uint8_t>(socket, currentValue);
    }
  } else {
    if (pos->second == currentValue) {
      // Compare succeeded: install the new value (moved; newValue is not
      // used afterwards) and notify watchers before replying.
      pos->second = std::move(newValue);

      // Send key update to all watching clients
      sendKeyUpdatesToClients(
          key, WatchResponseType::KEY_UPDATED, currentValue, pos->second);
    }
    // Reply with the now-stored value (unchanged if the compare failed).
    tcputil::sendVector<uint8_t>(socket, pos->second);
  }
}
382 | |
383 | void TCPStoreMasterDaemon::addHandler(int socket) { |
384 | std::string key = tcputil::recvString(socket); |
385 | int64_t addVal = tcputil::recvValue<int64_t>(socket); |
386 | |
387 | bool newKey = true; |
388 | std::vector<uint8_t> oldData; |
389 | auto it = tcpStore_.find(key); |
390 | if (it != tcpStore_.end()) { |
391 | oldData = it->second; |
392 | auto buf = reinterpret_cast<const char*>(it->second.data()); |
393 | auto len = it->second.size(); |
394 | addVal += std::stoll(std::string(buf, len)); |
395 | newKey = false; |
396 | } |
397 | auto addValStr = std::to_string(addVal); |
398 | std::vector<uint8_t> newData = |
399 | std::vector<uint8_t>(addValStr.begin(), addValStr.end()); |
400 | tcpStore_[key] = newData; |
401 | // Now send the new value |
402 | tcputil::sendValue<int64_t>(socket, addVal); |
403 | // On "add", wake up all clients that have been waiting |
404 | wakeupWaitingClients(key); |
405 | // Send key update to all watching clients |
406 | newKey ? sendKeyUpdatesToClients( |
407 | key, WatchResponseType::KEY_CREATED, oldData, newData) |
408 | : sendKeyUpdatesToClients( |
409 | key, WatchResponseType::KEY_UPDATED, oldData, newData); |
410 | } |
411 | |
412 | void TCPStoreMasterDaemon::getHandler(int socket) const { |
413 | std::string key = tcputil::recvString(socket); |
414 | auto data = tcpStore_.at(key); |
415 | tcputil::sendVector<uint8_t>(socket, data); |
416 | } |
417 | |
// Handles GETNUMKEYS: replies with the number of keys currently stored.
void TCPStoreMasterDaemon::getNumKeysHandler(int socket) const {
  tcputil::sendValue<int64_t>(socket, tcpStore_.size());
}
421 | |
422 | void TCPStoreMasterDaemon::deleteHandler(int socket) { |
423 | std::string key = tcputil::recvString(socket); |
424 | auto it = tcpStore_.find(key); |
425 | if (it != tcpStore_.end()) { |
426 | std::vector<uint8_t> oldData = it->second; |
427 | // Send key update to all watching clients |
428 | std::vector<uint8_t> newData; |
429 | sendKeyUpdatesToClients( |
430 | key, WatchResponseType::KEY_DELETED, oldData, newData); |
431 | } |
432 | auto numDeleted = tcpStore_.erase(key); |
433 | tcputil::sendValue<int64_t>(socket, numDeleted); |
434 | } |
435 | |
436 | void TCPStoreMasterDaemon::checkHandler(int socket) const { |
437 | SizeType nargs = 0; |
438 | tcputil::recvBytes<SizeType>(socket, &nargs, 1); |
439 | std::vector<std::string> keys(nargs); |
440 | for (const auto i : c10::irange(nargs)) { |
441 | keys[i] = tcputil::recvString(socket); |
442 | } |
443 | // Now we have received all the keys |
444 | if (checkKeys(keys)) { |
445 | tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY); |
446 | } else { |
447 | tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY); |
448 | } |
449 | } |
450 | |
451 | void TCPStoreMasterDaemon::waitHandler(int socket) { |
452 | SizeType nargs = 0; |
453 | tcputil::recvBytes<SizeType>(socket, &nargs, 1); |
454 | std::vector<std::string> keys(nargs); |
455 | for (const auto i : c10::irange(nargs)) { |
456 | keys[i] = tcputil::recvString(socket); |
457 | } |
458 | if (checkKeys(keys)) { |
459 | tcputil::sendValue<WaitResponseType>( |
460 | socket, WaitResponseType::STOP_WAITING); |
461 | } else { |
462 | int numKeysToAwait = 0; |
463 | for (auto& key : keys) { |
464 | // Only count keys that have not already been set |
465 | if (tcpStore_.find(key) == tcpStore_.end()) { |
466 | waitingSockets_[key].push_back(socket); |
467 | numKeysToAwait++; |
468 | } |
469 | } |
470 | keysAwaited_[socket] = numKeysToAwait; |
471 | } |
472 | } |
473 | |
// Handles WATCH_KEY: registers `socket` to receive future change
// notifications for the received key, then acknowledges the registration so
// the client's blocked watchKey() call can return.
void TCPStoreMasterDaemon::watchHandler(int socket) {
  std::string key = tcputil::recvString(socket);

  // Record the socket to respond to when the key is updated
  watchedSockets_[key].push_back(socket);

  // Send update to TCPStoreWorkerDaemon on client
  tcputil::sendValue<WatchResponseType>(
      socket, WatchResponseType::KEY_CALLBACK_REGISTERED);
}
484 | |
485 | bool TCPStoreMasterDaemon::checkKeys( |
486 | const std::vector<std::string>& keys) const { |
487 | return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) { |
488 | return tcpStore_.count(s) > 0; |
489 | }); |
490 | } |
491 | |
#ifdef _WIN32
// Windows event loop: WSAPoll with a short timeout so the stop event can be
// checked periodically (Windows has no self-pipe to poll on).
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  // receive the queries
  bool finished = false;
  while (!finished) {
    // NOTE(review): this clears revents for the first sockets_.size() entries
    // only, not all of fds; WSAPoll rewrites revents on every call, so this
    // looks redundant rather than harmful — confirm before relying on it.
    for (const auto i : c10::irange(sockets_.size())) {
      fds[i].revents = 0;
    }

    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      // Poll timed out with no events: check whether stop() signalled the
      // stop event, otherwise keep looping.
      auto rv = WaitForSingleObject(ghStopEvent_, 0);
      if (rv != WAIT_TIMEOUT) {
        finished = true;
        break;
      }
      continue;
    }

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      if (!(fds[0].revents & POLLIN)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      Socket socket = storeListenSocket_.accept();
      int rawSocket = socket.handle();
      sockets_.emplace_back(std::move(socket));
      tcputil::addPollfd(fds, rawSocket, POLLIN);
    }
    queryFds(fds);
  }
}
#else
// POSIX event loop: blocking poll() over the listening socket (fds[0]), the
// control pipe's read end (fds[1]) and one entry per accepted connection.
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
  // Although we haven't found any documentation or literature describing this,
  // we've seen cases that, under certain circumstances, the read end of the
  // pipe won't receive POLLHUP when the write end is closed. However, under
  // the same circumstances, writing to the pipe will guarantee POLLIN to be
  // received on the read end.
  //
  // For more reliable termination, the main thread will write a byte to the
  // pipe before closing it, and the background thread will poll for both
  // POLLIN and POLLHUP.
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);

  // receive the queries
  bool finished = false;
  while (!finished) {
    // NOTE(review): clears revents for the first sockets_.size() entries
    // only; poll() rewrites revents on every call, so this appears redundant
    // rather than harmful — confirm before relying on it.
    for (const auto i : c10::irange(sockets_.size())) {
      fds[i].revents = 0;
    }

    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      // `revents ^ POLLIN` is non-zero whenever revents is anything other
      // than exactly POLLIN (e.g. POLLERR/POLLHUP), which is treated as fatal.
      if (fds[0].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      Socket socket = storeListenSocket_.accept();
      int rawSocket = socket.handle();
      sockets_.emplace_back(std::move(socket));
      tcputil::addPollfd(fds, rawSocket, POLLIN);
    }

    // The pipe receives an event which tells us to shutdown the daemon
    if (fds[1].revents != 0) {
      // The main thread will write a byte to the pipe then close it before
      // joining the background thread
      if (fds[1].revents & ~(POLLIN | POLLHUP)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }
    queryFds(fds);
  }
}
#endif
592 | |
// Separate thread that is launched on all instances (including master)
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
 public:
  explicit TCPStoreWorkerDaemon(Socket&& listenSocket);
  ~TCPStoreWorkerDaemon() override;
  // Set the callback to run key change
  void setCallback(std::string key, WatchKeyCallback cb);
  // Blocks the calling thread until the master acknowledges a WATCH_KEY
  // registration (signalled by setCallbackRegistered() from the daemon).
  void waitForCallbackRegistration() {
    // Block until callback has been registered successfully
    std::unique_lock<std::mutex> callbackRegistrationLock(
        callbackRegistrationMutex_);
    callbackRegisteredCV_.wait(
        callbackRegistrationLock, [&] { return callbackRegisteredData_; });

    // Reset payload for next callback
    callbackRegisteredData_ = false;
  }
  // Called from the daemon thread when the master acknowledges a
  // registration; releases one waitForCallbackRegistration() caller.
  void setCallbackRegistered() {
    {
      std::unique_lock<std::mutex> callbackRegistrationLock(
          callbackRegistrationMutex_);
      callbackRegisteredData_ = true;
    }
    callbackRegisteredCV_.notify_one();
  }

 private:
  // Event loop body: receives watch notifications and runs callbacks.
  void run();
  // Reads one notification from `socket` and dispatches the matching callback.
  void callbackHandler(int socket);
  // List of callbacks map each watched key
  std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_{};
  // Guards keyToCallbacks_ (written from setCallback, read from the daemon).
  std::mutex keyToCallbacksMutex_{};
  // Guards callbackRegisteredData_ for the registration handshake below.
  std::mutex callbackRegistrationMutex_{};
  std::condition_variable callbackRegisteredCV_{};
  // True once the master has acknowledged the most recent registration.
  bool callbackRegisteredData_ = false;
};
630 | |
// TCPStoreListener class methods

// Simply start the daemon thread.
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(Socket&& listenSocket)
    : BackgroundThread{std::move(listenSocket)} {
  daemonThread_ = std::thread{&TCPStoreWorkerDaemon::run, this};
}

// dispose() must be called here (not in the base destructor) so the daemon
// thread is stopped and joined while this subclass is still alive.
TCPStoreWorkerDaemon::~TCPStoreWorkerDaemon() {
  dispose();
}
640 | |
641 | void TCPStoreWorkerDaemon::setCallback( |
642 | std::string key, |
643 | WatchKeyCallback callback) { |
644 | const std::lock_guard<std::mutex> lock(keyToCallbacksMutex_); |
645 | keyToCallbacks_[key] = callback; |
646 | } |
647 | |
// Runs all the callbacks that the worker has registered
void TCPStoreWorkerDaemon::callbackHandler(int socket) {
  auto watchResponse = tcputil::recvValue<WatchResponseType>(socket);
  if (watchResponse == WatchResponseType::KEY_CALLBACK_REGISTERED) {
    // Notify the waiting "watchKey" operation to return
    setCallbackRegistered();
    return;
  }
  // Remaining notification types carry key + old value + new value, in that
  // fixed wire order.
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> currentValueVec = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> newValueVec = tcputil::recvVector<uint8_t>(socket);
  // A freshly created key has no previous value...
  c10::optional<std::string> currentValue;
  if (watchResponse == WatchResponseType::KEY_CREATED) {
    assert(currentValueVec.empty());
    currentValue = c10::nullopt;
  } else {
    currentValue = std::string(currentValueVec.begin(), currentValueVec.end());
  }
  // ...and a deleted key has no new value.
  c10::optional<std::string> newValue;
  if (watchResponse == WatchResponseType::KEY_DELETED) {
    assert(newValueVec.empty());
    newValue = c10::nullopt;
  } else {
    newValue = std::string(newValueVec.begin(), newValueVec.end());
  }
  // at() throws if no callback was registered for this key; the exception
  // propagates out of the daemon loop.
  const std::lock_guard<std::mutex> lock(keyToCallbacksMutex_);
  keyToCallbacks_.at(key)(currentValue, newValue);
}
676 | |
#ifdef _WIN32
// Windows event loop: polls the single connection to the master with a short
// timeout so the stop event can be checked periodically.
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  while (true) {
    // Check control and exit early if triggered
    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
      if (rvPoll != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      auto rvData = WaitForSingleObject(ghStopEvent_, 0);
      if (rvData != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[0].fd);
  }
}
#else
// POSIX event loop: blocking poll() over the control pipe's read end (fds[0])
// and the connection to the master (fds[1]).
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  // Although we haven't found any documentation or literature describing this,
  // we've seen cases that, under certain circumstances, the read end of the
  // pipe won't receive POLLHUP when the write end is closed. However, under
  // the same circumstances, writing to the pipe will guarantee POLLIN to be
  // received on the read end.
  //
  // For more reliable termination, the main thread will write a byte to the
  // pipe before closing it, and the background thread will poll for both
  // POLLIN and POLLHUP.
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  while (true) {
    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // Check control and exit early if triggered
    // The pipe receives an event which tells us to shutdown the listener thread
    if (fds[0].revents != 0) {
      // The main thread will write a byte to the pipe then close it before
      // joining the background thread
      if (fds[0].revents & ~(POLLIN | POLLHUP)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[0].revents));
      }
      break;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data = 0;
    int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[1].fd);
  }
}
#endif
755 | |
756 | } // namespace |
757 | |
// Manages the lifecycle of a server daemon.
class TCPServer {
 public:
  // Creates (or, for multi-tenant stores, reuses) a server; see definition.
  static std::shared_ptr<TCPServer> start(const TCPStoreOptions& opts);

  // The TCP port the server is actually listening on (resolved even when the
  // requested port was 0).
  std::uint16_t port() const noexcept {
    return port_;
  }

  explicit TCPServer(
      std::uint16_t port,
      std::unique_ptr<TCPStoreMasterDaemon>&& daemon)
      : port_{port}, daemon_{std::move(daemon)} {}

 private:
  std::uint16_t port_;
  std::unique_ptr<TCPStoreMasterDaemon> daemon_;

  // We store weak references to all TCPServers for which the caller requested
  // multi-tenancy.
  static std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
      cachedServers_;

  // Guards cachedServers_.
  static std::mutex cache_mutex_;
};

// Static member definitions.
std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
    TCPServer::cachedServers_{};

std::mutex TCPServer::cache_mutex_{};
788 | |
789 | std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) { |
790 | auto startCore = [&opts]() { |
791 | Socket socket = Socket::listen(opts.port); |
792 | |
793 | std::uint16_t port = socket.port(); |
794 | |
795 | auto daemon = std::make_unique<TCPStoreMasterDaemon>(std::move(socket)); |
796 | |
797 | return std::make_shared<TCPServer>(port, std::move(daemon)); |
798 | }; |
799 | |
800 | std::shared_ptr<TCPServer> server{}; |
801 | |
802 | if (opts.multiTenant) { |
803 | std::lock_guard<std::mutex> guard{cache_mutex_}; |
804 | |
805 | // If the caller is okay with a multi-tenant store, first check if we |
806 | // already have a TCPServer running on the specified port. |
807 | if (opts.port > 0) { |
808 | auto pos = cachedServers_.find(opts.port); |
809 | if (pos != cachedServers_.end()) { |
810 | server = pos->second.lock(); |
811 | if (server != nullptr) { |
812 | return server; |
813 | } |
814 | |
815 | // Looks like the TCPStore has been disposed, make sure that we release |
816 | // the control block. |
817 | cachedServers_.erase(pos); |
818 | } |
819 | } |
820 | |
821 | server = startCore(); |
822 | |
823 | cachedServers_.emplace(server->port(), server); |
824 | } else { |
825 | server = startCore(); |
826 | } |
827 | |
828 | return server; |
829 | } |
830 | |
// Thin request/response wrapper over the socket connected to the master
// daemon. Not thread-safe; the owner serializes access.
class TCPClient {
 public:
  static std::unique_ptr<TCPClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts);

  // Sends a bare op code with no arguments.
  void sendCommand(QueryType type) {
    tcputil::sendValue<QueryType>(socket_.handle(), type);
  }

  // Sends an op code followed by a key argument.
  void sendCommandForKey(QueryType type, const std::string& key);

  // Sends a length-prefixed byte payload.
  void sendBytes(const std::vector<std::uint8_t>& value) {
    tcputil::sendVector<std::uint8_t>(socket_.handle(), value);
  }

  // Sends a count followed by that many length-prefixed strings.
  void sendStrings(c10::ArrayRef<std::string> value);

  // Sends a single fixed-size value of type T.
  template <typename T>
  void sendValue(const T& value) {
    tcputil::sendValue<T>(socket_.handle(), value);
  }

  // Receives a length-prefixed byte payload.
  std::vector<std::uint8_t> receiveBits() {
    return tcputil::recvVector<std::uint8_t>(socket_.handle());
  }

  // Receives a single fixed-size value of type T.
  template <typename T>
  T receiveValue() {
    return tcputil::recvValue<T>(socket_.handle());
  }

  // Applies `value` as the socket receive timeout; see definition.
  void setTimeout(std::chrono::milliseconds value);

  explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}

 private:
  Socket socket_;
};
870 | |
// Establishes a connection to the master at `addr`, honoring opts.timeout
// (truncated to whole seconds) as the connect timeout.
std::unique_ptr<TCPClient> TCPClient::connect(
    const SocketAddress& addr,
    const TCPStoreOptions& opts) {
  auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
  Socket socket = Socket::connect(
      addr.host, addr.port, SocketOptions{}.connect_timeout(timeout));

  return std::make_unique<TCPClient>(std::move(socket));
}
880 | |
881 | void TCPClient::sendCommandForKey(QueryType type, const std::string& key) { |
882 | tcputil::sendValue<QueryType>(socket_.handle(), type); |
883 | |
884 | bool withValue = type == QueryType::SET || type == QueryType::COMPARE_SET || |
885 | type == QueryType::ADD; |
886 | |
887 | tcputil::sendString(socket_.handle(), key, withValue); |
888 | } |
889 | |
890 | void TCPClient::sendStrings(c10::ArrayRef<std::string> value) { |
891 | std::size_t size = value.size(); |
892 | |
893 | tcputil::sendBytes<std::size_t>(socket_.handle(), &size, 1, size > 0); |
894 | |
895 | if (value.empty()) { |
896 | return; |
897 | } |
898 | |
899 | for (auto pos = value.begin(), last = value.end() - 1; pos <= last; ++pos) { |
900 | tcputil::sendString(socket_.handle(), *pos, pos != last); |
901 | } |
902 | } |
903 | |
// Applies `value` as the receive timeout (SO_RCVTIMEO) of the underlying
// socket. A zero duration means "no timeout" and leaves the socket unchanged.
void TCPClient::setTimeout(std::chrono::milliseconds value) {
  if (value == std::chrono::milliseconds::zero()) {
    return;
  }

  // Split milliseconds into the seconds/microseconds pair timeval expects;
  // Windows' timeval fields are plain longs, POSIX uses time_t/suseconds_t.
#ifdef _WIN32
  struct timeval timeoutTV = {
      static_cast<long>(value.count() / 1000),
      static_cast<long>((value.count() % 1000) * 1000)};
#else
  struct timeval timeoutTV = {
      .tv_sec = value.count() / 1000,
      .tv_usec = static_cast<suseconds_t>((value.count() % 1000) * 1000),
  };
#endif
  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      socket_.handle(),
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));
}
926 | |
// Client-side handle for the key-watch feature. It pairs a raw socket fd
// (used to send WATCH_KEY requests) with a background worker daemon that
// owns the Socket and listens for key-change notifications on the same
// connection.
class TCPCallbackClient {
 public:
  static std::unique_ptr<TCPCallbackClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts);

  // Registers `callback` for `key`; blocks until the daemon reports that the
  // registration was acknowledged.
  void setCallback(const std::string& key, WatchKeyCallback callback);

  explicit TCPCallbackClient(
      int rawSocket,
      std::unique_ptr<TCPStoreWorkerDaemon>&& daemon)
      : rawSocket_{rawSocket}, daemon_{std::move(daemon)} {}

 private:
  int rawSocket_; // non-owning view; the Socket itself lives in daemon_
  std::unique_ptr<TCPStoreWorkerDaemon> daemon_;
  std::mutex mutex_; // serializes concurrent setCallback() calls
};
945 | |
946 | std::unique_ptr<TCPCallbackClient> TCPCallbackClient::connect( |
947 | const SocketAddress& addr, |
948 | const TCPStoreOptions& opts) { |
949 | auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout); |
950 | Socket socket = Socket::connect( |
951 | addr.host, addr.port, SocketOptions{}.connect_timeout(timeout)); |
952 | |
953 | int rawSocket = socket.handle(); |
954 | |
955 | auto daemon = std::make_unique<TCPStoreWorkerDaemon>(std::move(socket)); |
956 | |
957 | return std::make_unique<TCPCallbackClient>(rawSocket, std::move(daemon)); |
958 | } |
959 | |
960 | void TCPCallbackClient::setCallback( |
961 | const std::string& key, |
962 | WatchKeyCallback callback) { |
963 | std::lock_guard<std::mutex> guard{mutex_}; |
964 | |
965 | daemon_->setCallback(key, callback); |
966 | |
967 | tcputil::sendValue<QueryType>(rawSocket_, QueryType::WATCH_KEY); |
968 | |
969 | tcputil::sendString(rawSocket_, key); |
970 | |
971 | daemon_->waitForCallbackRegistration(); |
972 | } |
973 | |
974 | } // namespace detail |
975 | |
976 | using detail::Socket; |
977 | |
978 | // TCPStore class methods |
// Legacy constructor kept for backward compatibility; packs the individual
// arguments into a TCPStoreOptions and delegates to the primary constructor.
TCPStore::TCPStore(
    const std::string& masterAddr,
    std::uint16_t masterPort,
    c10::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : TCPStore{
          masterAddr,
          TCPStoreOptions{
              masterPort,
              isServer,
              // Widen optional<int> to the optional<size_t> the options use.
              numWorkers ? c10::optional<std::size_t>(*numWorkers)
                         : c10::nullopt,
              waitWorkers,
              timeout}} {}
995 | |
996 | TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts) |
997 | : Store{opts.timeout}, |
998 | addr_{std::move(host)}, |
999 | numWorkers_{opts.numWorkers} { |
1000 | Socket::initialize(); |
1001 | |
1002 | if (opts.isServer) { |
1003 | server_ = detail::TCPServer::start(opts); |
1004 | |
1005 | addr_.port = server_->port(); |
1006 | } else { |
1007 | addr_.port = opts.port; |
1008 | } |
1009 | |
1010 | client_ = detail::TCPClient::connect(addr_, opts); |
1011 | |
1012 | if (opts.waitWorkers) { |
1013 | waitForWorkers(); |
1014 | } |
1015 | |
1016 | callbackClient_ = detail::TCPCallbackClient::connect(addr_, opts); |
1017 | } |
1018 | |
// Defaulted: members (clients, server) tear down their own connections.
TCPStore::~TCPStore() = default;
1020 | |
1021 | void TCPStore::waitForWorkers() { |
1022 | if (numWorkers_ == c10::nullopt) { |
1023 | return; |
1024 | } |
1025 | |
1026 | incrementValueBy(initKey_, 1); |
1027 | |
1028 | // Let server block until all workers have completed, this ensures that |
1029 | // the server daemon thread is always running until the very end |
1030 | if (server_) { |
1031 | const auto start = std::chrono::steady_clock::now(); |
1032 | while (true) { |
1033 | // TODO: Any chance to make this cleaner? |
1034 | std::vector<uint8_t> value = doGet(initKey_); |
1035 | auto buf = reinterpret_cast<const char*>(value.data()); |
1036 | auto len = value.size(); |
1037 | int numWorkersCompleted = std::stoi(std::string(buf, len)); |
1038 | if (numWorkersCompleted >= *numWorkers_) { |
1039 | break; |
1040 | } |
1041 | const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>( |
1042 | std::chrono::steady_clock::now() - start); |
1043 | if (timeout_ != kNoTimeout && elapsed > timeout_) { |
1044 | break; |
1045 | } |
1046 | /* sleep override */ |
1047 | std::this_thread::sleep_for(std::chrono::milliseconds(10)); |
1048 | } |
1049 | } |
1050 | } |
1051 | |
1052 | void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) { |
1053 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1054 | client_->sendCommandForKey(detail::QueryType::SET, keyPrefix_ + key); |
1055 | client_->sendBytes(data); |
1056 | } |
1057 | |
1058 | std::vector<uint8_t> TCPStore::compareSet( |
1059 | const std::string& key, |
1060 | const std::vector<uint8_t>& expectedValue, |
1061 | const std::vector<uint8_t>& desiredValue) { |
1062 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1063 | client_->sendCommandForKey(detail::QueryType::COMPARE_SET, keyPrefix_ + key); |
1064 | client_->sendBytes(expectedValue); |
1065 | client_->sendBytes(desiredValue); |
1066 | |
1067 | return client_->receiveBits(); |
1068 | } |
1069 | |
1070 | std::vector<uint8_t> TCPStore::get(const std::string& key) { |
1071 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1072 | return doGet(keyPrefix_ + key); |
1073 | } |
1074 | |
1075 | std::vector<uint8_t> TCPStore::doGet(const std::string& key) { |
1076 | doWait(key, timeout_); |
1077 | client_->sendCommandForKey(detail::QueryType::GET, key); |
1078 | return client_->receiveBits(); |
1079 | } |
1080 | |
1081 | int64_t TCPStore::add(const std::string& key, int64_t value) { |
1082 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1083 | return incrementValueBy(keyPrefix_ + key, value); |
1084 | } |
1085 | |
1086 | bool TCPStore::deleteKey(const std::string& key) { |
1087 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1088 | client_->sendCommandForKey(detail::QueryType::DELETE_KEY, keyPrefix_ + key); |
1089 | auto numDeleted = client_->receiveValue<std::int64_t>(); |
1090 | return numDeleted == 1; |
1091 | } |
1092 | |
1093 | void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) { |
1094 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1095 | callbackClient_->setCallback(keyPrefix_ + key, callback); |
1096 | } |
1097 | |
1098 | int64_t TCPStore::incrementValueBy(const std::string& key, int64_t delta) { |
1099 | client_->sendCommandForKey(detail::QueryType::ADD, key); |
1100 | client_->sendValue<std::int64_t>(delta); |
1101 | return client_->receiveValue<std::int64_t>(); |
1102 | } |
1103 | |
1104 | int64_t TCPStore::getNumKeys() { |
1105 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1106 | client_->sendCommand(detail::QueryType::GETNUMKEYS); |
1107 | return client_->receiveValue<std::int64_t>(); |
1108 | } |
1109 | |
1110 | bool TCPStore::check(const std::vector<std::string>& keys) { |
1111 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1112 | std::vector<std::string> prefixedKeys{}; |
1113 | prefixedKeys.reserve(keys.size()); |
1114 | for (const std::string& key : keys) { |
1115 | prefixedKeys.emplace_back(keyPrefix_ + key); |
1116 | } |
1117 | |
1118 | client_->sendCommand(detail::QueryType::CHECK); |
1119 | client_->sendStrings(prefixedKeys); |
1120 | |
1121 | auto response = client_->receiveValue<detail::CheckResponseType>(); |
1122 | if (response == detail::CheckResponseType::READY) { |
1123 | return true; |
1124 | } |
1125 | if (response == detail::CheckResponseType::NOT_READY) { |
1126 | return false; |
1127 | } |
1128 | TORCH_CHECK(false, "ready or not_ready response expected" ); |
1129 | } |
1130 | |
// Blocks until all `keys` are set, using the store's default timeout.
void TCPStore::wait(const std::vector<std::string>& keys) {
  wait(keys, timeout_);
}
1134 | |
1135 | void TCPStore::wait( |
1136 | const std::vector<std::string>& keys, |
1137 | const std::chrono::milliseconds& timeout) { |
1138 | const std::lock_guard<std::mutex> lock(activeOpLock_); |
1139 | std::vector<std::string> prefixedKeys{}; |
1140 | prefixedKeys.reserve(keys.size()); |
1141 | for (const std::string& key : keys) { |
1142 | prefixedKeys.emplace_back(keyPrefix_ + key); |
1143 | } |
1144 | |
1145 | doWait(prefixedKeys, timeout); |
1146 | } |
1147 | |
1148 | void TCPStore::doWait( |
1149 | c10::ArrayRef<std::string> keys, |
1150 | std::chrono::milliseconds timeout) { |
1151 | // TODO: Should we revert to the original timeout at the end of the call? |
1152 | client_->setTimeout(timeout); |
1153 | |
1154 | client_->sendCommand(detail::QueryType::WAIT); |
1155 | client_->sendStrings(keys); |
1156 | |
1157 | auto response = client_->receiveValue<detail::WaitResponseType>(); |
1158 | if (response != detail::WaitResponseType::STOP_WAITING) { |
1159 | TORCH_CHECK(false, "Stop_waiting response is expected" ); |
1160 | } |
1161 | } |
1162 | |
1163 | } // namespace c10d |
1164 | |