1 | #pragma once |
2 | |
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include <torch/csrc/distributed/c10d/Store.hpp>
8 | |
9 | namespace c10d { |
namespace detail {

// Opaque implementation types. In this header they are only used through
// smart pointers (see TCPStore's private members), so forward declarations
// are sufficient.
class TCPCallbackClient;
class TCPClient;
class TCPServer;

// A plain host/port pair. A default-constructed SocketAddress has an empty
// host and port 0.
struct SocketAddress {
  std::string host;
  std::uint16_t port = 0;
};

} // namespace detail
24 | |
25 | struct TCPStoreOptions { |
26 | static constexpr std::uint16_t kDefaultPort = 29500; |
27 | |
28 | std::uint16_t port = kDefaultPort; |
29 | bool isServer = false; |
30 | c10::optional<std::size_t> numWorkers = c10::nullopt; |
31 | bool waitWorkers = true; |
32 | std::chrono::milliseconds timeout = Store::kDefaultTimeout; |
33 | |
34 | // A boolean value indicating whether multiple store instances can be |
35 | // initialized with the same host:port pair. |
36 | bool multiTenant = false; |
37 | }; |
38 | |
39 | class TORCH_API TCPStore : public Store { |
40 | public: |
41 | explicit TCPStore(std::string host, const TCPStoreOptions& opts = {}); |
42 | |
43 | [[deprecated("Use TCPStore(host, opts) instead." )]] explicit TCPStore( |
44 | const std::string& masterAddr, |
45 | std::uint16_t masterPort, |
46 | c10::optional<int> numWorkers = c10::nullopt, |
47 | bool isServer = false, |
48 | const std::chrono::milliseconds& timeout = kDefaultTimeout, |
49 | bool waitWorkers = true); |
50 | |
51 | ~TCPStore() override; |
52 | |
53 | void set(const std::string& key, const std::vector<uint8_t>& value) override; |
54 | |
55 | std::vector<uint8_t> compareSet( |
56 | const std::string& key, |
57 | const std::vector<uint8_t>& expectedValue, |
58 | const std::vector<uint8_t>& desiredValue) override; |
59 | |
60 | std::vector<uint8_t> get(const std::string& key) override; |
61 | |
62 | int64_t add(const std::string& key, int64_t value) override; |
63 | |
64 | bool deleteKey(const std::string& key) override; |
65 | |
66 | // NOTE: calling other TCPStore APIs inside the callback is NOT threadsafe |
67 | // watchKey() is a blocking operation. It will register the socket on |
68 | // TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will |
69 | // return once it has verified the callback is registered on both background |
70 | // threads. Only one thread can call watchKey() at a time. |
71 | void watchKey(const std::string& key, WatchKeyCallback callback) override; |
72 | |
73 | bool check(const std::vector<std::string>& keys) override; |
74 | |
75 | int64_t getNumKeys() override; |
76 | |
77 | void wait(const std::vector<std::string>& keys) override; |
78 | |
79 | void wait( |
80 | const std::vector<std::string>& keys, |
81 | const std::chrono::milliseconds& timeout) override; |
82 | |
83 | // Waits for all workers to join. |
84 | void waitForWorkers(); |
85 | |
86 | // Returns the hostname used by the TCPStore. |
87 | const std::string& getHost() const noexcept { |
88 | return addr_.host; |
89 | } |
90 | |
91 | // Returns the port used by the TCPStore. |
92 | std::uint16_t getPort() const noexcept { |
93 | return addr_.port; |
94 | } |
95 | |
96 | private: |
97 | int64_t incrementValueBy(const std::string& key, int64_t delta); |
98 | |
99 | std::vector<uint8_t> doGet(const std::string& key); |
100 | |
101 | void doWait( |
102 | c10::ArrayRef<std::string> keys, |
103 | std::chrono::milliseconds timeout); |
104 | |
105 | detail::SocketAddress addr_; |
106 | std::shared_ptr<detail::TCPServer> server_; |
107 | std::unique_ptr<detail::TCPClient> client_; |
108 | std::unique_ptr<detail::TCPCallbackClient> callbackClient_; |
109 | c10::optional<std::size_t> numWorkers_; |
110 | |
111 | const std::string initKey_ = "init/" ; |
112 | const std::string keyPrefix_ = "/" ; |
113 | std::mutex activeOpLock_; |
114 | }; |
115 | |
116 | } // namespace c10d |
117 | |