1#include <c10/util/irange.h>
2#include <torch/csrc/distributed/c10d/TCPStore.hpp>
3
4#include <fcntl.h>
5#include <algorithm>
6#include <array>
7#include <system_error>
8#include <thread>
9#include <unordered_map>
10#include <utility>
11
12#ifdef _WIN32
13#include <io.h>
14#include <winsock2.h>
15#else
16#include <poll.h>
17#include <unistd.h>
18#endif
19
20#ifdef _WIN32
21#include <torch/csrc/distributed/c10d/WinSockUtils.hpp>
22#else
23#include <torch/csrc/distributed/c10d/UnixSockUtils.hpp>
24#endif
25
26#include <torch/csrc/distributed/c10d/socket.h>
27
28namespace c10d {
29namespace detail {
30namespace {
31
// Abstract base class to handle thread state for TCPStoreMasterDaemon and
// TCPStoreWorkerDaemon. Contains the windows/unix implementations to signal a
// shutdown sequence for the thread
class BackgroundThread {
 public:
  explicit BackgroundThread(Socket&& storeListenSocket);

  virtual ~BackgroundThread() = 0;

 protected:
  // Stops the daemon thread, joins it, and releases sockets and the stop
  // signal. Must be called from every concrete subclass destructor (see the
  // WARNING above dispose()'s definition below).
  void dispose();

  Socket storeListenSocket_;
  // Launched by the subclass constructor; runs the subclass's run() loop.
  std::thread daemonThread_{};
  // Accepted client connections, owned here so dispose() can close them.
  std::vector<Socket> sockets_{};
#ifdef _WIN32
  // On Windows the daemon polls with this timeout so it can periodically
  // check whether ghStopEvent_ has been signaled.
  const std::chrono::milliseconds checkTimeout_ = std::chrono::milliseconds{10};
  HANDLE ghStopEvent_{};
#else
  // Self-pipe used to signal shutdown: [0] = read end polled by the daemon,
  // [1] = write end used by stop(). -1 means closed / not initialized.
  std::array<int, 2> controlPipeFd_{{-1, -1}};
#endif

 private:
  // Initialization for shutdown signal
  void initStopSignal();
  // Triggers the shutdown signal
  void stop();
  // Joins the thread
  void join();
  // Clean up the shutdown signal
  void closeStopSignal();
};
64
// Background thread parent class methods

// Takes ownership of the listening socket and creates the platform-specific
// stop signal. Note the daemon thread itself is launched by the subclass
// constructor, not here.
BackgroundThread::BackgroundThread(Socket&& storeListenSocket)
    : storeListenSocket_{std::move(storeListenSocket)} {
  // Signal instance destruction to the daemon thread.
  initStopSignal();
}
71
72BackgroundThread::~BackgroundThread() = default;
73
// WARNING:
// Since we rely on the subclass for the daemon thread clean-up, we cannot
// destruct our member variables in the destructor. The subclass must call
// dispose() in its own destructor.
//
// The order below matters: stop() must come before join() so the daemon's
// poll loop observes the shutdown signal and exits.
void BackgroundThread::dispose() {
  // Stop the run
  stop();
  // Join the thread
  join();
  // Close unclosed sockets
  sockets_.clear();
  // Now close the rest control pipe
  closeStopSignal();
}
88
89void BackgroundThread::join() {
90 daemonThread_.join();
91}
92
#ifdef _WIN32
// Creates the manual-reset event (initially unsignaled) used to ask the
// daemon thread to stop.
void BackgroundThread::initStopSignal() {
  ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL);
  if (ghStopEvent_ == NULL) {
    TORCH_CHECK(
        false,
        "Failed to create the control pipe to start the "
        "BackgroundThread run");
  }
}

// Releases the stop event handle.
void BackgroundThread::closeStopSignal() {
  CloseHandle(ghStopEvent_);
}

// Signals the event; the daemon's run() loop checks it between polls.
void BackgroundThread::stop() {
  SetEvent(ghStopEvent_);
}
#else
// Creates the self-pipe whose read end the daemon polls for shutdown.
void BackgroundThread::initStopSignal() {
  if (pipe(controlPipeFd_.data()) == -1) {
    TORCH_CHECK(
        false,
        "Failed to create the control pipe to start the "
        "BackgroundThread run");
  }
}

// Closes whichever pipe ends are still open (stop() already closed and
// reset the write end).
void BackgroundThread::closeStopSignal() {
  for (int fd : controlPipeFd_) {
    if (fd != -1) {
      ::close(fd);
    }
  }
}

// Writes one byte to the pipe before closing the write end: the daemon polls
// for both POLLIN and POLLHUP, since POLLHUP alone has been observed to be
// unreliable (see the comment in the run() implementations below).
void BackgroundThread::stop() {
  if (controlPipeFd_[1] != -1) {
    ::write(controlPipeFd_[1], "\0", 1);
    // close the write end of the pipe
    ::close(controlPipeFd_[1]);
    controlPipeFd_[1] = -1;
  }
}
#endif
138
// Wire-protocol command byte sent by a client as the first byte of every
// request; see TCPStoreMasterDaemon::query() for the dispatch.
enum class QueryType : uint8_t {
  SET,
  COMPARE_SET,
  GET,
  ADD,
  CHECK,
  WAIT,
  GETNUMKEYS,
  WATCH_KEY,
  DELETE_KEY,
};
150
// Reply to a CHECK query: whether all requested keys are present.
enum class CheckResponseType : uint8_t { READY, NOT_READY };

// Reply to a WAIT query once every awaited key has been set.
enum class WaitResponseType : uint8_t { STOP_WAITING };

// Messages pushed by the master to TCPStoreWorkerDaemon for watched keys.
enum class WatchResponseType : uint8_t {
  KEY_UPDATED,
  KEY_CREATED,
  KEY_DELETED,
  KEY_CALLBACK_REGISTERED
};
161
162// Separate thread that is only launched on master
// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
 public:
  explicit TCPStoreMasterDaemon(Socket&& storeListenSocket);

  ~TCPStoreMasterDaemon() override;

 private:
  // Daemon loop: accepts connections and dispatches client queries.
  void run();
  // Processes every fd in `fds` (past the listen socket / control pipe) that
  // has a pending event; drops connections that raise during handling.
  void queryFds(std::vector<struct pollfd>& fds);
  // Reads one command byte from `socket` and dispatches to a handler.
  void query(int socket);

  // The master runs on a single thread so only
  // one handler can be executed at a time
  void setHandler(int socket);
  void compareSetHandler(int socket);
  void addHandler(int socket);
  void getHandler(int socket) const;
  void checkHandler(int socket) const;
  void getNumKeysHandler(int socket) const;
  void deleteHandler(int socket);
  void waitHandler(int socket);
  void watchHandler(int socket);

  bool checkKeys(const std::vector<std::string>& keys) const;
  // Helper function to alerts waiting workers, used in setHandler, getHandler
  void wakeupWaitingClients(const std::string& key);
  // Helper function used when the key is changed
  // used in setHandler, addHandler, getHandler, deleteHandler
  void sendKeyUpdatesToClients(
      const std::string& key,
      const enum WatchResponseType& type,
      std::vector<uint8_t>& oldData,
      std::vector<uint8_t>& newData);
  // The actual key/value storage.
  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // From key -> the list of sockets waiting on the key
  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
  // From socket -> number of keys awaited
  std::unordered_map<int, size_t> keysAwaited_;
  // From key -> the list of sockets watching the key
  std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};
204
// Simply start the daemon thread
TCPStoreMasterDaemon::TCPStoreMasterDaemon(Socket&& storeListenSocket)
    : BackgroundThread{std::move(storeListenSocket)} {
  daemonThread_ = std::thread{&TCPStoreMasterDaemon::run, this};
}
210
// Subclass must call dispose() itself (see BackgroundThread::dispose).
TCPStoreMasterDaemon::~TCPStoreMasterDaemon() {
  dispose();
}
214
// Services every client fd that poll() flagged with an event. On any
// exception from a handler the offending connection is dropped and all of
// its wait-tracking state is purged, keeping fds and sockets_ in sync
// (fds[i] corresponds to sockets_[i - CONNECT_SOCKET_OFFSET]).
void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
  // Skipping the fds[0] and fds[1],
  // fds[0] is master's listening socket
  // fds[1] is control pipe's reading fd, it is not for Windows platform
  for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
    if (fds[fdIdx].revents == 0) {
      continue;
    }

    // Now query the socket that has the event
    try {
      query(fds[fdIdx].fd);
    } catch (...) {
      // There was an error when processing query. Probably an exception
      // occurred in recv/send what would indicate that socket on the other
      // side has been closed. If the closing was due to normal exit, then
      // the store should continue executing. Otherwise, if it was different
      // exception, other connections will get an exception once they try to
      // use the store. We will go ahead and close this connection whenever
      // we hit an exception here.

      // Remove all the tracking state of the close FD
      for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
        for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
          if (*vecIt == fds[fdIdx].fd) {
            vecIt = it->second.erase(vecIt);
          } else {
            ++vecIt;
          }
        }
        if (it->second.empty()) {
          it = waitingSockets_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
        if (it->first == fds[fdIdx].fd) {
          it = keysAwaited_.erase(it);
        } else {
          ++it;
        }
      }
      // Erase both parallel entries, then step back so the element that
      // shifted into this slot is not skipped by the loop's ++fdIdx.
      fds.erase(fds.begin() + fdIdx);
      sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
      --fdIdx;
      continue;
    }
  }
}
265
266// query communicates with the worker. The format
267// of the query is as follows:
268// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
269// or, in the case of wait
270// type of query | number of args | size of arg1 | arg1 | ...
271void TCPStoreMasterDaemon::query(int socket) {
272 QueryType qt;
273 tcputil::recvBytes<QueryType>(socket, &qt, 1);
274 if (qt == QueryType::SET) {
275 setHandler(socket);
276
277 } else if (qt == QueryType::COMPARE_SET) {
278 compareSetHandler(socket);
279
280 } else if (qt == QueryType::ADD) {
281 addHandler(socket);
282
283 } else if (qt == QueryType::GET) {
284 getHandler(socket);
285
286 } else if (qt == QueryType::CHECK) {
287 checkHandler(socket);
288
289 } else if (qt == QueryType::WAIT) {
290 waitHandler(socket);
291
292 } else if (qt == QueryType::GETNUMKEYS) {
293 getNumKeysHandler(socket);
294
295 } else if (qt == QueryType::DELETE_KEY) {
296 deleteHandler(socket);
297
298 } else if (qt == QueryType::WATCH_KEY) {
299 watchHandler(socket);
300
301 } else {
302 TORCH_CHECK(false, "Unexpected query type");
303 }
304}
305
// Notifies clients blocked in WAIT on `key`. Each waiting socket's
// outstanding-key counter (keysAwaited_) is decremented; STOP_WAITING is
// sent only once the counter reaches zero, i.e. once every key the client
// waited on has been set.
void TCPStoreMasterDaemon::wakeupWaitingClients(const std::string& key) {
  auto socketsToWait = waitingSockets_.find(key);
  if (socketsToWait != waitingSockets_.end()) {
    for (int socket : socketsToWait->second) {
      if (--keysAwaited_[socket] == 0) {
        tcputil::sendValue<WaitResponseType>(
            socket, WaitResponseType::STOP_WAITING);
      }
    }
    waitingSockets_.erase(socketsToWait);
  }
}
318
319void TCPStoreMasterDaemon::sendKeyUpdatesToClients(
320 const std::string& key,
321 const enum WatchResponseType& type,
322 std::vector<uint8_t>& oldData,
323 std::vector<uint8_t>& newData) {
324 for (int socket : watchedSockets_[key]) {
325 tcputil::sendValue<WatchResponseType>(socket, type);
326 tcputil::sendString(socket, key, true);
327 tcputil::sendVector<uint8_t>(socket, oldData);
328 tcputil::sendVector<uint8_t>(socket, newData);
329 }
330}
331
332void TCPStoreMasterDaemon::setHandler(int socket) {
333 std::string key = tcputil::recvString(socket);
334 std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
335 std::vector<uint8_t> oldData;
336 bool newKey = true;
337 auto it = tcpStore_.find(key);
338 if (it != tcpStore_.end()) {
339 oldData = it->second;
340 newKey = false;
341 }
342 tcpStore_[key] = newData;
343 // On "set", wake up all clients that have been waiting
344 wakeupWaitingClients(key);
345 // Send key update to all watching clients
346 newKey ? sendKeyUpdatesToClients(
347 key, WatchResponseType::KEY_CREATED, oldData, newData)
348 : sendKeyUpdatesToClients(
349 key, WatchResponseType::KEY_UPDATED, oldData, newData);
350}
351
// Handles COMPARE_SET: atomically replaces the value only if it currently
// equals the expected value. The reply is always the value now stored (or,
// in the missing-key mismatch case below, the caller's expected value).
void TCPStoreMasterDaemon::compareSetHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> currentValue = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> newValue = tcputil::recvVector<uint8_t>(socket);

  auto pos = tcpStore_.find(key);
  if (pos == tcpStore_.end()) {
    if (currentValue.empty()) {
      // Expected "no value" and key is absent: create it.
      tcpStore_[key] = newValue;

      // Send key update to all watching clients
      sendKeyUpdatesToClients(
          key, WatchResponseType::KEY_CREATED, currentValue, newValue);
      tcputil::sendVector<uint8_t>(socket, newValue);
    } else {
      // TODO: This code path is not ideal as we are "lying" to the caller in
      // case the key does not exist. We should come up with a working solution.
      tcputil::sendVector<uint8_t>(socket, currentValue);
    }
  } else {
    if (pos->second == currentValue) {
      pos->second = std::move(newValue);

      // Send key update to all watching clients
      sendKeyUpdatesToClients(
          key, WatchResponseType::KEY_UPDATED, currentValue, pos->second);
    }
    // Reply with the stored value whether or not the swap happened.
    tcputil::sendVector<uint8_t>(socket, pos->second);
  }
}
382
383void TCPStoreMasterDaemon::addHandler(int socket) {
384 std::string key = tcputil::recvString(socket);
385 int64_t addVal = tcputil::recvValue<int64_t>(socket);
386
387 bool newKey = true;
388 std::vector<uint8_t> oldData;
389 auto it = tcpStore_.find(key);
390 if (it != tcpStore_.end()) {
391 oldData = it->second;
392 auto buf = reinterpret_cast<const char*>(it->second.data());
393 auto len = it->second.size();
394 addVal += std::stoll(std::string(buf, len));
395 newKey = false;
396 }
397 auto addValStr = std::to_string(addVal);
398 std::vector<uint8_t> newData =
399 std::vector<uint8_t>(addValStr.begin(), addValStr.end());
400 tcpStore_[key] = newData;
401 // Now send the new value
402 tcputil::sendValue<int64_t>(socket, addVal);
403 // On "add", wake up all clients that have been waiting
404 wakeupWaitingClients(key);
405 // Send key update to all watching clients
406 newKey ? sendKeyUpdatesToClients(
407 key, WatchResponseType::KEY_CREATED, oldData, newData)
408 : sendKeyUpdatesToClients(
409 key, WatchResponseType::KEY_UPDATED, oldData, newData);
410}
411
412void TCPStoreMasterDaemon::getHandler(int socket) const {
413 std::string key = tcputil::recvString(socket);
414 auto data = tcpStore_.at(key);
415 tcputil::sendVector<uint8_t>(socket, data);
416}
417
// Handles GETNUMKEYS: replies with the current number of stored keys.
void TCPStoreMasterDaemon::getNumKeysHandler(int socket) const {
  tcputil::sendValue<int64_t>(socket, tcpStore_.size());
}
421
422void TCPStoreMasterDaemon::deleteHandler(int socket) {
423 std::string key = tcputil::recvString(socket);
424 auto it = tcpStore_.find(key);
425 if (it != tcpStore_.end()) {
426 std::vector<uint8_t> oldData = it->second;
427 // Send key update to all watching clients
428 std::vector<uint8_t> newData;
429 sendKeyUpdatesToClients(
430 key, WatchResponseType::KEY_DELETED, oldData, newData);
431 }
432 auto numDeleted = tcpStore_.erase(key);
433 tcputil::sendValue<int64_t>(socket, numDeleted);
434}
435
436void TCPStoreMasterDaemon::checkHandler(int socket) const {
437 SizeType nargs = 0;
438 tcputil::recvBytes<SizeType>(socket, &nargs, 1);
439 std::vector<std::string> keys(nargs);
440 for (const auto i : c10::irange(nargs)) {
441 keys[i] = tcputil::recvString(socket);
442 }
443 // Now we have received all the keys
444 if (checkKeys(keys)) {
445 tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::READY);
446 } else {
447 tcputil::sendValue<CheckResponseType>(socket, CheckResponseType::NOT_READY);
448 }
449}
450
// Handles WAIT: receives a count followed by that many keys. If all exist,
// replies STOP_WAITING immediately; otherwise registers the socket against
// each missing key and records how many keys it still awaits, so
// wakeupWaitingClients() can reply once they are all set.
void TCPStoreMasterDaemon::waitHandler(int socket) {
  SizeType nargs = 0;
  tcputil::recvBytes<SizeType>(socket, &nargs, 1);
  std::vector<std::string> keys(nargs);
  for (const auto i : c10::irange(nargs)) {
    keys[i] = tcputil::recvString(socket);
  }
  if (checkKeys(keys)) {
    tcputil::sendValue<WaitResponseType>(
        socket, WaitResponseType::STOP_WAITING);
  } else {
    int numKeysToAwait = 0;
    for (auto& key : keys) {
      // Only count keys that have not already been set
      if (tcpStore_.find(key) == tcpStore_.end()) {
        waitingSockets_[key].push_back(socket);
        numKeysToAwait++;
      }
    }
    keysAwaited_[socket] = numKeysToAwait;
  }
}
473
// Handles WATCH_KEY: registers the socket to receive updates for the key and
// acknowledges registration so the client's watchKey() call can unblock.
void TCPStoreMasterDaemon::watchHandler(int socket) {
  std::string key = tcputil::recvString(socket);

  // Record the socket to respond to when the key is updated
  watchedSockets_[key].push_back(socket);

  // Send update to TCPStoreWorkerDaemon on client
  tcputil::sendValue<WatchResponseType>(
      socket, WatchResponseType::KEY_CALLBACK_REGISTERED);
}
484
485bool TCPStoreMasterDaemon::checkKeys(
486 const std::vector<std::string>& keys) const {
487 return std::all_of(keys.begin(), keys.end(), [this](const std::string& s) {
488 return tcpStore_.count(s) > 0;
489 });
490}
491
#ifdef _WIN32
// Windows daemon loop: polls with a short timeout so the stop event can be
// checked when no socket activity occurs.
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (const auto i : c10::irange(sockets_.size())) {
      fds[i].revents = 0;
    }

    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      // Poll timed out with no activity: check whether stop() signaled the
      // event; WAIT_TIMEOUT means "not signaled yet".
      auto rv = WaitForSingleObject(ghStopEvent_, 0);
      if (rv != WAIT_TIMEOUT) {
        finished = true;
        break;
      }
      continue;
    }

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      if (!(fds[0].revents & POLLIN)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      Socket socket = storeListenSocket_.accept();
      int rawSocket = socket.handle();
      sockets_.emplace_back(std::move(socket));
      tcputil::addPollfd(fds, rawSocket, POLLIN);
    }
    queryFds(fds);
  }
}
534#else
// Unix daemon loop: blocks in poll() on the listen socket, the control pipe,
// and every accepted client connection.
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);
  // Although we haven't found any documentation or literature describing this,
  // we've seen cases that, under certain circumstances, the read end of the
  // pipe won't receive POLLHUP when the write end is closed. However, under
  // the same circumstances, writing to the pipe will guarantee POLLIN to be
  // received on the read end.
  //
  // For more reliable termination, the main thread will write a byte to the
  // pipe before closing it, and the background thread will poll for both
  // POLLIN and POLLHUP.
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (const auto i : c10::irange(sockets_.size())) {
      fds[i].revents = 0;
    }

    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      // XOR with POLLIN: any revents value other than exactly POLLIN
      // (errors, hangup, etc.) aborts the daemon.
      if (fds[0].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      Socket socket = storeListenSocket_.accept();
      int rawSocket = socket.handle();
      sockets_.emplace_back(std::move(socket));
      tcputil::addPollfd(fds, rawSocket, POLLIN);
    }

    // The pipe receives an event which tells us to shutdown the daemon
    if (fds[1].revents != 0) {
      // The main thread will write a byte to the pipe then close it before
      // joining the background thread
      if (fds[1].revents & ~(POLLIN | POLLHUP)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }
    queryFds(fds);
  }
}
591#endif
592
// Separate thread that is launched on all instances (including master)
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
 public:
  explicit TCPStoreWorkerDaemon(Socket&& listenSocket);
  ~TCPStoreWorkerDaemon() override;
  // Set the callback to run key change
  void setCallback(std::string key, WatchKeyCallback cb);
  // Blocks the calling (client) thread until the daemon thread observes the
  // master's KEY_CALLBACK_REGISTERED acknowledgement.
  void waitForCallbackRegistration() {
    // Block until callback has been registered successfully
    std::unique_lock<std::mutex> callbackRegistrationLock(
        callbackRegistrationMutex_);
    callbackRegisteredCV_.wait(
        callbackRegistrationLock, [&] { return callbackRegisteredData_; });

    // Reset payload for next callback
    callbackRegisteredData_ = false;
  }
  // Called from the daemon thread to release waitForCallbackRegistration().
  void setCallbackRegistered() {
    {
      std::unique_lock<std::mutex> callbackRegistrationLock(
          callbackRegistrationMutex_);
      callbackRegisteredData_ = true;
    }
    callbackRegisteredCV_.notify_one();
  }

 private:
  void run();
  void callbackHandler(int socket);
  // List of callbacks map each watched key
  std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_{};
  // Guards keyToCallbacks_ (written by setCallback, read by callbackHandler).
  std::mutex keyToCallbacksMutex_{};
  std::mutex callbackRegistrationMutex_{};
  std::condition_variable callbackRegisteredCV_{};
  // Condition payload: true once registration was acknowledged.
  bool callbackRegisteredData_ = false;
};
630
631// TCPStoreListener class methods
// TCPStoreListener class methods

// Simply starts the daemon thread on the worker's connection to the master.
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(Socket&& listenSocket)
    : BackgroundThread{std::move(listenSocket)} {
  daemonThread_ = std::thread{&TCPStoreWorkerDaemon::run, this};
}
636
// Subclass must call dispose() itself (see BackgroundThread::dispose).
TCPStoreWorkerDaemon::~TCPStoreWorkerDaemon() {
  dispose();
}
640
641void TCPStoreWorkerDaemon::setCallback(
642 std::string key,
643 WatchKeyCallback callback) {
644 const std::lock_guard<std::mutex> lock(keyToCallbacksMutex_);
645 keyToCallbacks_[key] = callback;
646}
647
// Runs all the callbacks that the worker has registered
//
// Reads one message pushed by the master. A KEY_CALLBACK_REGISTERED ack just
// unblocks waitForCallbackRegistration(); any other type carries
// key/old/new payloads and triggers the registered callback. An empty old
// value on KEY_CREATED (resp. new value on KEY_DELETED) is surfaced to the
// callback as c10::nullopt.
void TCPStoreWorkerDaemon::callbackHandler(int socket) {
  auto watchResponse = tcputil::recvValue<WatchResponseType>(socket);
  if (watchResponse == WatchResponseType::KEY_CALLBACK_REGISTERED) {
    // Notify the waiting "watchKey" operation to return
    setCallbackRegistered();
    return;
  }
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> currentValueVec = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> newValueVec = tcputil::recvVector<uint8_t>(socket);
  c10::optional<std::string> currentValue;
  if (watchResponse == WatchResponseType::KEY_CREATED) {
    assert(currentValueVec.empty());
    currentValue = c10::nullopt;
  } else {
    currentValue = std::string(currentValueVec.begin(), currentValueVec.end());
  }
  c10::optional<std::string> newValue;
  if (watchResponse == WatchResponseType::KEY_DELETED) {
    assert(newValueVec.empty());
    newValue = c10::nullopt;
  } else {
    newValue = std::string(newValueVec.begin(), newValueVec.end());
  }
  // at() throws if no callback was registered for the key; queryFds-style
  // handling does not apply here, the callback map is expected to have it.
  const std::lock_guard<std::mutex> lock(keyToCallbacksMutex_);
  keyToCallbacks_.at(key)(currentValue, newValue);
}
676
#ifdef _WIN32
// Windows worker loop: polls the connection to the master with a short
// timeout so the stop event can be checked between messages.
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  while (true) {
    // Check control and exit early if triggered
    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
      if (rvPoll != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      auto rvData = WaitForSingleObject(ghStopEvent_, 0);
      if (rvData != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[0].fd);
  }
}
710#else
// Unix worker loop: fds[0] is the control pipe's read end, fds[1] is the
// connection to the master.
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  // Although we haven't found any documentation or literature describing this,
  // we've seen cases that, under certain circumstances, the read end of the
  // pipe won't receive POLLHUP when the write end is closed. However, under
  // the same circumstances, writing to the pipe will guarantee POLLIN to be
  // received on the read end.
  //
  // For more reliable termination, the main thread will write a byte to the
  // pipe before closing it, and the background thread will poll for both
  // POLLIN and POLLHUP.
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLIN | POLLHUP);
  tcputil::addPollfd(fds, storeListenSocket_.handle(), POLLIN);

  while (true) {
    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // Check control and exit early if triggered
    // The pipe receives an event which tells us to shutdown the listener thread
    if (fds[0].revents != 0) {
      // The main thread will write a byte to the pipe then close it before
      // joining the background thread
      if (fds[0].revents & ~(POLLIN | POLLHUP)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[0].revents));
      }
      break;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data = 0;
    int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[1].fd);
  }
}
754#endif
755
756} // namespace
757
// Manages the lifecycle of a server daemon.
class TCPServer {
 public:
  // Creates (or, in multi-tenant mode, reuses) a server listening on
  // opts.port; see the definition below.
  static std::shared_ptr<TCPServer> start(const TCPStoreOptions& opts);

  std::uint16_t port() const noexcept {
    return port_;
  }

  explicit TCPServer(
      std::uint16_t port,
      std::unique_ptr<TCPStoreMasterDaemon>&& daemon)
      : port_{port}, daemon_{std::move(daemon)} {}

 private:
  std::uint16_t port_;
  std::unique_ptr<TCPStoreMasterDaemon> daemon_;

  // We store weak references to all TCPServers for which the caller requested
  // multi-tenancy.
  static std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
      cachedServers_;

  // Guards cachedServers_.
  static std::mutex cache_mutex_;
};
783
// Out-of-line definitions of TCPServer's static members declared above.
std::unordered_map<std::uint16_t, std::weak_ptr<TCPServer>>
    TCPServer::cachedServers_{};

std::mutex TCPServer::cache_mutex_{};
788
// Starts a TCP server daemon. In multi-tenant mode a live server already
// bound to the requested port is reused; otherwise a fresh one is created
// and (in multi-tenant mode) cached under its actual port.
std::shared_ptr<TCPServer> TCPServer::start(const TCPStoreOptions& opts) {
  // Binds, resolves the actual port (relevant when opts.port == 0), and
  // launches the master daemon on the listening socket.
  auto startCore = [&opts]() {
    Socket socket = Socket::listen(opts.port);

    std::uint16_t port = socket.port();

    auto daemon = std::make_unique<TCPStoreMasterDaemon>(std::move(socket));

    return std::make_shared<TCPServer>(port, std::move(daemon));
  };

  std::shared_ptr<TCPServer> server{};

  if (opts.multiTenant) {
    std::lock_guard<std::mutex> guard{cache_mutex_};

    // If the caller is okay with a multi-tenant store, first check if we
    // already have a TCPServer running on the specified port.
    if (opts.port > 0) {
      auto pos = cachedServers_.find(opts.port);
      if (pos != cachedServers_.end()) {
        server = pos->second.lock();
        if (server != nullptr) {
          return server;
        }

        // Looks like the TCPStore has been disposed, make sure that we release
        // the control block.
        cachedServers_.erase(pos);
      }
    }

    server = startCore();

    cachedServers_.emplace(server->port(), server);
  } else {
    server = startCore();
  }

  return server;
}
830
// Thin blocking client over the TCPStore wire protocol: each method writes
// or reads one protocol element on the connection to the master daemon.
class TCPClient {
 public:
  static std::unique_ptr<TCPClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts);

  // Sends a command byte with no key (e.g. GETNUMKEYS).
  void sendCommand(QueryType type) {
    tcputil::sendValue<QueryType>(socket_.handle(), type);
  }

  void sendCommandForKey(QueryType type, const std::string& key);

  void sendBytes(const std::vector<std::uint8_t>& value) {
    tcputil::sendVector<std::uint8_t>(socket_.handle(), value);
  }

  void sendStrings(c10::ArrayRef<std::string> value);

  template <typename T>
  void sendValue(const T& value) {
    tcputil::sendValue<T>(socket_.handle(), value);
  }

  std::vector<std::uint8_t> receiveBits() {
    return tcputil::recvVector<std::uint8_t>(socket_.handle());
  }

  template <typename T>
  T receiveValue() {
    return tcputil::recvValue<T>(socket_.handle());
  }

  // Sets the receive timeout on the underlying socket; see definition below.
  void setTimeout(std::chrono::milliseconds value);

  explicit TCPClient(Socket&& socket) : socket_{std::move(socket)} {}

 private:
  Socket socket_;
};
870
871std::unique_ptr<TCPClient> TCPClient::connect(
872 const SocketAddress& addr,
873 const TCPStoreOptions& opts) {
874 auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
875 Socket socket = Socket::connect(
876 addr.host, addr.port, SocketOptions{}.connect_timeout(timeout));
877
878 return std::make_unique<TCPClient>(std::move(socket));
879}
880
881void TCPClient::sendCommandForKey(QueryType type, const std::string& key) {
882 tcputil::sendValue<QueryType>(socket_.handle(), type);
883
884 bool withValue = type == QueryType::SET || type == QueryType::COMPARE_SET ||
885 type == QueryType::ADD;
886
887 tcputil::sendString(socket_.handle(), key, withValue);
888}
889
890void TCPClient::sendStrings(c10::ArrayRef<std::string> value) {
891 std::size_t size = value.size();
892
893 tcputil::sendBytes<std::size_t>(socket_.handle(), &size, 1, size > 0);
894
895 if (value.empty()) {
896 return;
897 }
898
899 for (auto pos = value.begin(), last = value.end() - 1; pos <= last; ++pos) {
900 tcputil::sendString(socket_.handle(), *pos, pos != last);
901 }
902}
903
// Applies `value` as the socket's receive timeout (SO_RCVTIMEO). A zero
// duration is treated as "leave the default (blocking) behavior" and makes
// this a no-op.
void TCPClient::setTimeout(std::chrono::milliseconds value) {
  if (value == std::chrono::milliseconds::zero()) {
    return;
  }

  // Split milliseconds into the seconds/microseconds pair timeval expects.
#ifdef _WIN32
  struct timeval timeoutTV = {
      static_cast<long>(value.count() / 1000),
      static_cast<long>((value.count() % 1000) * 1000)};
#else
  struct timeval timeoutTV = {
      .tv_sec = value.count() / 1000,
      .tv_usec = static_cast<suseconds_t>((value.count() % 1000) * 1000),
  };
#endif
  SYSCHECK_ERR_RETURN_NEG1(::setsockopt(
      socket_.handle(),
      SOL_SOCKET,
      SO_RCVTIMEO,
      reinterpret_cast<char*>(&timeoutTV),
      sizeof(timeoutTV)));
}
926
// Client-side endpoint for watchKey(): owns the worker daemon that receives
// key-change notifications and the raw socket used to register watches.
class TCPCallbackClient {
 public:
  static std::unique_ptr<TCPCallbackClient> connect(
      const SocketAddress& addr,
      const TCPStoreOptions& opts);

  void setCallback(const std::string& key, WatchKeyCallback callback);

  explicit TCPCallbackClient(
      int rawSocket,
      std::unique_ptr<TCPStoreWorkerDaemon>&& daemon)
      : rawSocket_{rawSocket}, daemon_{std::move(daemon)} {}

 private:
  // Raw handle of the socket owned by daemon_; used to send WATCH_KEY.
  int rawSocket_;
  std::unique_ptr<TCPStoreWorkerDaemon> daemon_;
  // Serializes setCallback() calls.
  std::mutex mutex_;
};
945
// Opens a dedicated connection to the master and hands ownership of the
// socket to a TCPStoreWorkerDaemon; the raw handle is kept so setCallback()
// can write registration requests on the same connection.
std::unique_ptr<TCPCallbackClient> TCPCallbackClient::connect(
    const SocketAddress& addr,
    const TCPStoreOptions& opts) {
  auto timeout = std::chrono::duration_cast<std::chrono::seconds>(opts.timeout);
  Socket socket = Socket::connect(
      addr.host, addr.port, SocketOptions{}.connect_timeout(timeout));

  int rawSocket = socket.handle();

  auto daemon = std::make_unique<TCPStoreWorkerDaemon>(std::move(socket));

  return std::make_unique<TCPCallbackClient>(rawSocket, std::move(daemon));
}
959
960void TCPCallbackClient::setCallback(
961 const std::string& key,
962 WatchKeyCallback callback) {
963 std::lock_guard<std::mutex> guard{mutex_};
964
965 daemon_->setCallback(key, callback);
966
967 tcputil::sendValue<QueryType>(rawSocket_, QueryType::WATCH_KEY);
968
969 tcputil::sendString(rawSocket_, key);
970
971 daemon_->waitForCallbackRegistration();
972}
973
974} // namespace detail
975
976using detail::Socket;
977
// TCPStore class methods

// Legacy constructor kept for backward compatibility; forwards all
// arguments to the TCPStoreOptions-based constructor below.
TCPStore::TCPStore(
    const std::string& masterAddr,
    std::uint16_t masterPort,
    c10::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : TCPStore{
          masterAddr,
          TCPStoreOptions{
              masterPort,
              isServer,
              numWorkers ? c10::optional<std::size_t>(*numWorkers)
                         : c10::nullopt,
              waitWorkers,
              timeout}} {}
995
// Main constructor: when isServer, spins up (or joins, in multi-tenant mode)
// the server daemon first, then connects the regular client and the
// callback client, optionally blocking until all workers have checked in.
TCPStore::TCPStore(std::string host, const TCPStoreOptions& opts)
    : Store{opts.timeout},
      addr_{std::move(host)},
      numWorkers_{opts.numWorkers} {
  Socket::initialize();

  if (opts.isServer) {
    server_ = detail::TCPServer::start(opts);

    // The server may have been bound to an OS-assigned port (opts.port == 0),
    // so take the actual port from it.
    addr_.port = server_->port();
  } else {
    addr_.port = opts.port;
  }

  client_ = detail::TCPClient::connect(addr_, opts);

  if (opts.waitWorkers) {
    waitForWorkers();
  }

  callbackClient_ = detail::TCPCallbackClient::connect(addr_, opts);
}
1018
1019TCPStore::~TCPStore() = default;
1020
// Rendezvous barrier used at construction: each instance increments a shared
// counter, and the server additionally spins until the counter reaches
// numWorkers_ (or the timeout elapses, in which case it gives up silently).
// No-op when the worker count was not provided.
void TCPStore::waitForWorkers() {
  if (numWorkers_ == c10::nullopt) {
    return;
  }

  incrementValueBy(initKey_, 1);

  // Let server block until all workers have completed, this ensures that
  // the server daemon thread is always running until the very end
  if (server_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      // TODO: Any chance to make this cleaner?
      std::vector<uint8_t> value = doGet(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      // The counter is stored as an ASCII decimal string (see addHandler).
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= *numWorkers_) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        break;
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}
1051
1052void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
1053 const std::lock_guard<std::mutex> lock(activeOpLock_);
1054 client_->sendCommandForKey(detail::QueryType::SET, keyPrefix_ + key);
1055 client_->sendBytes(data);
1056}
1057
1058std::vector<uint8_t> TCPStore::compareSet(
1059 const std::string& key,
1060 const std::vector<uint8_t>& expectedValue,
1061 const std::vector<uint8_t>& desiredValue) {
1062 const std::lock_guard<std::mutex> lock(activeOpLock_);
1063 client_->sendCommandForKey(detail::QueryType::COMPARE_SET, keyPrefix_ + key);
1064 client_->sendBytes(expectedValue);
1065 client_->sendBytes(desiredValue);
1066
1067 return client_->receiveBits();
1068}
1069
1070std::vector<uint8_t> TCPStore::get(const std::string& key) {
1071 const std::lock_guard<std::mutex> lock(activeOpLock_);
1072 return doGet(keyPrefix_ + key);
1073}
1074
// Internal get: `key` is expected to be fully prefixed already. First waits
// (bounded by the store timeout) for the key to exist, so reading a
// not-yet-set key blocks instead of failing, then issues the GET and
// returns the raw bytes.
std::vector<uint8_t> TCPStore::doGet(const std::string& key) {
  doWait(key, timeout_);
  client_->sendCommandForKey(detail::QueryType::GET, key);
  return client_->receiveBits();
}
1080
1081int64_t TCPStore::add(const std::string& key, int64_t value) {
1082 const std::lock_guard<std::mutex> lock(activeOpLock_);
1083 return incrementValueBy(keyPrefix_ + key, value);
1084}
1085
1086bool TCPStore::deleteKey(const std::string& key) {
1087 const std::lock_guard<std::mutex> lock(activeOpLock_);
1088 client_->sendCommandForKey(detail::QueryType::DELETE_KEY, keyPrefix_ + key);
1089 auto numDeleted = client_->receiveValue<std::int64_t>();
1090 return numDeleted == 1;
1091}
1092
1093void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) {
1094 const std::lock_guard<std::mutex> lock(activeOpLock_);
1095 callbackClient_->setCallback(keyPrefix_ + key, callback);
1096}
1097
// Sends an ADD request for `key` (expected to be fully prefixed already)
// and returns the post-increment counter value reported by the server.
// Note: unlike the public API methods, this helper does not take
// activeOpLock_ itself; it is called both under the lock (add) and
// without it (waitForWorkers, during construction).
int64_t TCPStore::incrementValueBy(const std::string& key, int64_t delta) {
  client_->sendCommandForKey(detail::QueryType::ADD, key);
  client_->sendValue<std::int64_t>(delta);
  return client_->receiveValue<std::int64_t>();
}
1103
1104int64_t TCPStore::getNumKeys() {
1105 const std::lock_guard<std::mutex> lock(activeOpLock_);
1106 client_->sendCommand(detail::QueryType::GETNUMKEYS);
1107 return client_->receiveValue<std::int64_t>();
1108}
1109
1110bool TCPStore::check(const std::vector<std::string>& keys) {
1111 const std::lock_guard<std::mutex> lock(activeOpLock_);
1112 std::vector<std::string> prefixedKeys{};
1113 prefixedKeys.reserve(keys.size());
1114 for (const std::string& key : keys) {
1115 prefixedKeys.emplace_back(keyPrefix_ + key);
1116 }
1117
1118 client_->sendCommand(detail::QueryType::CHECK);
1119 client_->sendStrings(prefixedKeys);
1120
1121 auto response = client_->receiveValue<detail::CheckResponseType>();
1122 if (response == detail::CheckResponseType::READY) {
1123 return true;
1124 }
1125 if (response == detail::CheckResponseType::NOT_READY) {
1126 return false;
1127 }
1128 TORCH_CHECK(false, "ready or not_ready response expected");
1129}
1130
// Convenience overload: waits for `keys` using the store's default timeout.
void TCPStore::wait(const std::vector<std::string>& keys) {
  wait(keys, timeout_);
}
1134
1135void TCPStore::wait(
1136 const std::vector<std::string>& keys,
1137 const std::chrono::milliseconds& timeout) {
1138 const std::lock_guard<std::mutex> lock(activeOpLock_);
1139 std::vector<std::string> prefixedKeys{};
1140 prefixedKeys.reserve(keys.size());
1141 for (const std::string& key : keys) {
1142 prefixedKeys.emplace_back(keyPrefix_ + key);
1143 }
1144
1145 doWait(prefixedKeys, timeout);
1146}
1147
1148void TCPStore::doWait(
1149 c10::ArrayRef<std::string> keys,
1150 std::chrono::milliseconds timeout) {
1151 // TODO: Should we revert to the original timeout at the end of the call?
1152 client_->setTimeout(timeout);
1153
1154 client_->sendCommand(detail::QueryType::WAIT);
1155 client_->sendStrings(keys);
1156
1157 auto response = client_->receiveValue<detail::WaitResponseType>();
1158 if (response != detail::WaitResponseType::STOP_WAITING) {
1159 TORCH_CHECK(false, "Stop_waiting response is expected");
1160 }
1161}
1162
1163} // namespace c10d
1164