#include <c10/util/irange.h>
#include "StoreTestCommon.hpp"

#include <atomic>
#include <chrono>
#include <cstdlib>
#include <future>
#include <iostream>
#include <memory>
#include <string>
#include <system_error>
#include <thread>
#include <vector>

#include <gtest/gtest.h>

#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
14
// Store timeout (ms) installed when a test expects a get() to fail quickly.
constexpr int64_t kShortStoreTimeoutMillis{100};
// How long (ms) tests wait for a watchKey callback before giving up.
constexpr int64_t kStoreCallbackTimeoutMillis{5000};
// Default TCPStore timeout for stores created in this file, in seconds.
constexpr int defaultTimeout{20};
18
19c10::intrusive_ptr<c10d::TCPStore> _createServer(
20 int numWorkers = 1,
21 int timeout = defaultTimeout) {
22 return c10::make_intrusive<c10d::TCPStore>(
23 "127.0.0.1",
24 c10d::TCPStoreOptions{
25 /* port */ 0,
26 /* isServer */ true,
27 numWorkers,
28 /* waitWorkers */ false,
29 /* timeout */ std::chrono::seconds(timeout)});
30}
31
32// Different ports for different tests.
33void testHelper(const std::string& prefix = "") {
34 constexpr auto numThreads = 16;
35 constexpr auto numWorkers = numThreads + 1;
36
37 auto serverTCPStore = _createServer(numWorkers);
38
39 auto serverStore =
40 c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
41 // server store
42 auto serverThread = std::thread([&serverStore, &serverTCPStore] {
43 // Wait for all workers to join.
44 serverTCPStore->waitForWorkers();
45
46 // Basic set/get on the server store
47 c10d::test::set(*serverStore, "key0", "value0");
48 c10d::test::set(*serverStore, "key1", "value1");
49 c10d::test::set(*serverStore, "key2", "value2");
50 c10d::test::check(*serverStore, "key0", "value0");
51 c10d::test::check(*serverStore, "key1", "value1");
52 c10d::test::check(*serverStore, "key2", "value2");
53 serverStore->add("counter", 1);
54 auto numKeys = serverStore->getNumKeys();
55 // We expect 5 keys since 3 are added above, 'counter' is added by the
56 // helper thread, and the init key to coordinate workers.
57 EXPECT_EQ(numKeys, 5);
58
59 // Check compareSet, does not check return value
60 c10d::test::compareSet(
61 *serverStore, "key0", "wrongExpectedValue", "newValue");
62 c10d::test::check(*serverStore, "key0", "value0");
63 c10d::test::compareSet(*serverStore, "key0", "value0", "newValue");
64 c10d::test::check(*serverStore, "key0", "newValue");
65
66 auto delSuccess = serverStore->deleteKey("key0");
67 // Ensure that the key was successfully deleted
68 EXPECT_TRUE(delSuccess);
69 auto delFailure = serverStore->deleteKey("badKeyName");
70 // The key was not in the store so the delete operation should have failed
71 // and returned false.
72 EXPECT_FALSE(delFailure);
73 numKeys = serverStore->getNumKeys();
74 EXPECT_EQ(numKeys, 4);
75 auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis);
76 serverStore->setTimeout(timeout);
77 EXPECT_THROW(serverStore->get("key0"), c10::Error);
78 });
79
80 // Hammer on TCPStore
81 std::vector<std::thread> threads;
82 constexpr auto numIterations = 1000;
83 c10d::test::Semaphore sem1, sem2;
84
85 c10d::TCPStoreOptions opts{};
86 opts.port = serverTCPStore->getPort();
87 opts.numWorkers = numWorkers;
88
89 // Each thread will have a client store to send/recv data
90 std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
91 std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
92 for (const auto i : c10::irange(numThreads)) {
93 clientTCPStores.push_back(
94 c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
95 clientStores.push_back(
96 c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
97 }
98
99 std::string expectedCounterRes =
100 std::to_string(numThreads * numIterations + 1);
101
102 for (const auto i : c10::irange(numThreads)) {
103 threads.emplace_back(
104 std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] {
105 for (C10_UNUSED const auto j : c10::irange(numIterations)) {
106 clientStores[i]->add("counter", 1);
107 }
108 // Let each thread set and get key on its client store
109 std::string key = "thread_" + std::to_string(i);
110 for (const auto j : c10::irange(numIterations)) {
111 std::string val = "thread_val_" + std::to_string(j);
112 c10d::test::set(*clientStores[i], key, val);
113 c10d::test::check(*clientStores[i], key, val);
114 }
115
116 sem1.post();
117 sem2.wait();
118 // Check the counter results
119 c10d::test::check(*clientStores[i], "counter", expectedCounterRes);
120 // Now check other threads' written data
121 for (const auto j : c10::irange(numThreads)) {
122 if (j == i) {
123 continue;
124 }
125 std::string key = "thread_" + std::to_string(i);
126 std::string val = "thread_val_" + std::to_string(numIterations - 1);
127 c10d::test::check(*clientStores[i], key, val);
128 }
129 }));
130 }
131
132 sem1.wait(numThreads);
133 sem2.post(numThreads);
134
135 for (auto& thread : threads) {
136 thread.join();
137 }
138
139 serverThread.join();
140
141 // Clear the store to test that client disconnect won't shutdown the store
142 clientStores.clear();
143 clientTCPStores.clear();
144
145 // Check that the counter has the expected value
146 c10d::test::check(*serverStore, "counter", expectedCounterRes);
147
148 // Check that each threads' written data from the main thread
149 for (const auto i : c10::irange(numThreads)) {
150 std::string key = "thread_" + std::to_string(i);
151 std::string val = "thread_val_" + std::to_string(numIterations - 1);
152 c10d::test::check(*serverStore, key, val);
153 }
154}
155
156void testWatchKeyCallback(const std::string& prefix = "") {
157 // Callback function increments counter of the total number of callbacks that
158 // were run
159 std::promise<int> numCallbacksExecutedPromise;
160 std::atomic<int> numCallbacksExecuted{0};
161 constexpr int numThreads = 16;
162 constexpr int keyChangeOperation = 3;
163 c10d::WatchKeyCallback callback =
164 [=, &numCallbacksExecuted, &numCallbacksExecutedPromise](
165 c10::optional<std::string> /* unused */,
166 c10::optional<std::string> /* unused */) {
167 numCallbacksExecuted++;
168 if (numCallbacksExecuted == numThreads * keyChangeOperation * 2) {
169 numCallbacksExecutedPromise.set_value(numCallbacksExecuted);
170 }
171 };
172
173 const int numWorkers = numThreads + 1;
174 auto serverTCPStore = _createServer(numWorkers);
175 auto serverStore =
176 c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore);
177
178 c10d::TCPStoreOptions opts{};
179 opts.port = serverTCPStore->getPort();
180 opts.numWorkers = numWorkers;
181
182 // Each thread will have a client store to send/recv data
183 std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores;
184 std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores;
185 for (const auto i : c10::irange(numThreads)) {
186 clientTCPStores.push_back(
187 c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts));
188 clientStores.push_back(
189 c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i]));
190 }
191
192 // Start watching key on server and client stores
193 std::string internalKey = "internalKey";
194 std::string internalKeyCount = "internalKeyCount";
195 for (const auto i : c10::irange(numThreads)) {
196 serverStore->watchKey(internalKey + std::to_string(i), callback);
197 serverStore->watchKey(internalKeyCount + std::to_string(i), callback);
198 clientStores[i]->watchKey(internalKey + std::to_string(i), callback);
199 clientStores[i]->watchKey(internalKeyCount + std::to_string(i), callback);
200 }
201
202 std::vector<std::thread> threads;
203 std::atomic<int> keyChangeOperationCount{0};
204 for (const auto i : c10::irange(numThreads)) {
205 threads.emplace_back(std::thread([=,
206 &clientStores,
207 &internalKey,
208 &internalKeyCount,
209 &keyChangeOperationCount] {
210 // Let each thread set and get key on its client store
211 std::string key = internalKey + std::to_string(i);
212 std::string keyCounter = internalKeyCount + std::to_string(i);
213 std::string val = "thread_val_" + std::to_string(i);
214 // The set, compareSet, add methods count as key change operations
215 c10d::test::set(*clientStores[i], key, val);
216 c10d::test::compareSet(*clientStores[i], key, val, "newValue");
217 clientStores[i]->add(keyCounter, i);
218 keyChangeOperationCount += keyChangeOperation * 2;
219 c10d::test::check(*clientStores[i], key, "newValue");
220 c10d::test::check(*clientStores[i], keyCounter, std::to_string(i));
221 }));
222 }
223
224 // Ensures that internal_key has been "set" and "get"
225 for (auto& thread : threads) {
226 thread.join();
227 }
228
229 std::future<int> numCallbacksExecutedFuture =
230 numCallbacksExecutedPromise.get_future();
231 std::chrono::milliseconds span(kStoreCallbackTimeoutMillis);
232 if (numCallbacksExecutedFuture.wait_for(span) == std::future_status::timeout)
233 TORCH_CHECK(false, "Callback execution timed out.");
234
235 // Check number of callbacks executed equal to number of key change operations
236 // Wait for all callbacks to be triggered
237 EXPECT_EQ(keyChangeOperationCount, numCallbacksExecutedFuture.get());
238}
239
240TEST(TCPStoreTest, testHelper) {
241 testHelper();
242}
243
244TEST(TCPStoreTest, testHelperPrefix) {
245 testHelper("testPrefix");
246}
247
248TEST(TCPStoreTest, testWatchKeyCallback) {
249 testWatchKeyCallback();
250}
251
252TEST(TCPStoreTest, testWatchKeyCallbackWithPrefix) {
253 testWatchKeyCallback("testPrefix");
254}
255
256// Helper function to create a key on the store, watch it, and run the callback
257void testKeyChangeHelper(
258 c10d::Store& store,
259 std::string key,
260 const c10::optional<std::string>& expectedOldValue,
261 const c10::optional<std::string>& expectedNewValue) {
262 std::exception_ptr eptr = nullptr;
263 std::promise<bool> callbackPromise;
264
265 // Test the correctness of new_value and old_value
266 c10d::WatchKeyCallback callback = [expectedOldValue,
267 expectedNewValue,
268 &callbackPromise,
269 &eptr](
270 c10::optional<std::string> oldValue,
271 c10::optional<std::string> newValue) {
272 try {
273 EXPECT_EQ(expectedOldValue.value_or("NONE"), oldValue.value_or("NONE"));
274 EXPECT_EQ(expectedNewValue.value_or("NONE"), newValue.value_or("NONE"));
275 } catch (...) {
276 eptr = std::current_exception();
277 }
278 callbackPromise.set_value(true);
279 };
280 store.watchKey(key, callback);
281
282 // Perform the specified update according to key
283 if (key == "testEmptyKeyValue" || key == "testRegularKeyValue" ||
284 key == "testWatchKeyCreate") {
285 c10d::test::set(store, key, expectedNewValue.value());
286 } else if (key == "testWatchKeyAdd") {
287 store.add(key, std::stoi(expectedNewValue.value()));
288 } else if (key == "testWatchKeyDelete") {
289 store.deleteKey(key);
290 }
291
292 // Test that the callback is fired and the expected values are correct
293 std::future<bool> callbackFuture = callbackPromise.get_future();
294 std::chrono::milliseconds span(kStoreCallbackTimeoutMillis);
295 if (callbackFuture.wait_for(span) == std::future_status::timeout)
296 TORCH_CHECK(false, "Callback execution timed out.");
297
298 // Any exceptions raised from asserts should be rethrown
299 if (eptr)
300 std::rethrow_exception(eptr);
301}
302
303TEST(TCPStoreTest, testKeyEmptyUpdate) {
304 auto store = _createServer();
305
306 std::string key = "testEmptyKeyValue";
307 c10d::test::set(*store, key, "");
308 store->get(key);
309 testKeyChangeHelper(*store, key, "", "2");
310}
311
312TEST(TCPStoreTest, testKeyUpdate) {
313 auto store = _createServer();
314
315 std::string key = "testRegularKeyValue";
316 c10d::test::set(*store, key, "1");
317 store->get(key);
318 testKeyChangeHelper(*store, key, "1", "2");
319}
320
321TEST(TCPStoreTest, testKeyCreate) {
322 auto store = _createServer();
323
324 std::string key = "testWatchKeyCreate";
325 testKeyChangeHelper(*store, key, c10::nullopt, "2");
326}
327
328TEST(TCPStoreTest, testKeyAdd) {
329 auto store = _createServer();
330
331 std::string key = "testWatchKeyAdd";
332 testKeyChangeHelper(*store, key, c10::nullopt, "2");
333}
334
335TEST(TCPStoreTest, testKeyDelete) {
336 auto store = _createServer();
337
338 std::string key = "testWatchKeyDelete";
339 c10d::test::set(*store, key, "1");
340 store->get(key);
341 testKeyChangeHelper(*store, key, "1", c10::nullopt);
342}
343
344TEST(TCPStoreTest, testCleanShutdown) {
345 int numWorkers = 2;
346
347 auto serverTCPStore = std::make_unique<c10d::TCPStore>(
348 "127.0.0.1",
349 0,
350 numWorkers,
351 true,
352 std::chrono::seconds(defaultTimeout),
353 /* wait */ false);
354 c10d::test::set(*serverTCPStore, "key", "val");
355
356 auto clientTCPStore = c10::make_intrusive<c10d::TCPStore>(
357 "127.0.0.1",
358 c10d::TCPStoreOptions{
359 /* port */ serverTCPStore->getPort(),
360 /* isServer */ false,
361 numWorkers,
362 /* waitWorkers */ false,
363 /* timeout */ std::chrono::seconds(defaultTimeout)});
364 clientTCPStore->get("key");
365
366 auto clientThread = std::thread([&clientTCPStore] {
367 EXPECT_THROW(clientTCPStore->get("invalid_key"), std::system_error);
368 });
369
370 // start server shutdown during a client request
371 serverTCPStore = nullptr;
372
373 clientThread.join();
374}
375
376TEST(TCPStoreTest, testMultiTenantStores) {
377 c10d::TCPStoreOptions opts{};
378 opts.isServer = true;
379 opts.multiTenant = true;
380
381 // Construct two server stores on the same port.
382 auto store1 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
383 auto store2 = c10::make_intrusive<c10d::TCPStore>("localhost", opts);
384
385 // Assert that the two stores share the same server.
386 c10d::test::set(*store1, "key0", "value0");
387 c10d::test::check(*store2, "key0", "value0");
388
389 // Dispose the second instance and assert that the server is still alive.
390 store2.reset();
391
392 c10d::test::set(*store1, "key0", "value0");
393 c10d::test::check(*store1, "key0", "value0");
394}
395