1#pragma once
2
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include <torch/csrc/distributed/c10d/Store.hpp>
8
9namespace c10d {
namespace detail {

// Implementation classes that TCPStore refers to only through smart pointers
// (server_, client_, callbackClient_); they are defined outside this header.
class TCPServer;
class TCPClient;
class TCPCallbackClient;

// A host name together with a port. TCPStore keeps one of these as `addr_`
// and exposes its fields through getHost() / getPort().
struct SocketAddress {
  // Host name; an empty string when not set.
  std::string host;
  // Port number; 0 when not set.
  std::uint16_t port = 0;
};

} // namespace detail
24
25struct TCPStoreOptions {
26 static constexpr std::uint16_t kDefaultPort = 29500;
27
28 std::uint16_t port = kDefaultPort;
29 bool isServer = false;
30 c10::optional<std::size_t> numWorkers = c10::nullopt;
31 bool waitWorkers = true;
32 std::chrono::milliseconds timeout = Store::kDefaultTimeout;
33
34 // A boolean value indicating whether multiple store instances can be
35 // initialized with the same host:port pair.
36 bool multiTenant = false;
37};
38
39class TORCH_API TCPStore : public Store {
40 public:
41 explicit TCPStore(std::string host, const TCPStoreOptions& opts = {});
42
43 [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore(
44 const std::string& masterAddr,
45 std::uint16_t masterPort,
46 c10::optional<int> numWorkers = c10::nullopt,
47 bool isServer = false,
48 const std::chrono::milliseconds& timeout = kDefaultTimeout,
49 bool waitWorkers = true);
50
51 ~TCPStore() override;
52
53 void set(const std::string& key, const std::vector<uint8_t>& value) override;
54
55 std::vector<uint8_t> compareSet(
56 const std::string& key,
57 const std::vector<uint8_t>& expectedValue,
58 const std::vector<uint8_t>& desiredValue) override;
59
60 std::vector<uint8_t> get(const std::string& key) override;
61
62 int64_t add(const std::string& key, int64_t value) override;
63
64 bool deleteKey(const std::string& key) override;
65
66 // NOTE: calling other TCPStore APIs inside the callback is NOT threadsafe
67 // watchKey() is a blocking operation. It will register the socket on
68 // TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will
69 // return once it has verified the callback is registered on both background
70 // threads. Only one thread can call watchKey() at a time.
71 void watchKey(const std::string& key, WatchKeyCallback callback) override;
72
73 bool check(const std::vector<std::string>& keys) override;
74
75 int64_t getNumKeys() override;
76
77 void wait(const std::vector<std::string>& keys) override;
78
79 void wait(
80 const std::vector<std::string>& keys,
81 const std::chrono::milliseconds& timeout) override;
82
83 // Waits for all workers to join.
84 void waitForWorkers();
85
86 // Returns the hostname used by the TCPStore.
87 const std::string& getHost() const noexcept {
88 return addr_.host;
89 }
90
91 // Returns the port used by the TCPStore.
92 std::uint16_t getPort() const noexcept {
93 return addr_.port;
94 }
95
96 private:
97 int64_t incrementValueBy(const std::string& key, int64_t delta);
98
99 std::vector<uint8_t> doGet(const std::string& key);
100
101 void doWait(
102 c10::ArrayRef<std::string> keys,
103 std::chrono::milliseconds timeout);
104
105 detail::SocketAddress addr_;
106 std::shared_ptr<detail::TCPServer> server_;
107 std::unique_ptr<detail::TCPClient> client_;
108 std::unique_ptr<detail::TCPCallbackClient> callbackClient_;
109 c10::optional<std::size_t> numWorkers_;
110
111 const std::string initKey_ = "init/";
112 const std::string keyPrefix_ = "/";
113 std::mutex activeOpLock_;
114};
115
116} // namespace c10d
117