1 | #include <c10/util/irange.h> |
2 | #include "StoreTestCommon.hpp" |
3 | |
4 | #include <cstdlib> |
5 | #include <future> |
6 | #include <iostream> |
7 | #include <system_error> |
8 | #include <thread> |
9 | |
10 | #include <gtest/gtest.h> |
11 | |
12 | #include <torch/csrc/distributed/c10d/PrefixStore.hpp> |
13 | #include <torch/csrc/distributed/c10d/TCPStore.hpp> |
14 | |
// Store timeout (milliseconds) set on the server store right before it reads
// a deleted key, so the expected failure surfaces quickly.
constexpr int64_t kShortStoreTimeoutMillis{100};
// Upper bound (milliseconds) the tests wait for watch-key callbacks to fire.
constexpr int64_t kStoreCallbackTimeoutMillis{5000};
// Default store timeout, in seconds (wrapped in std::chrono::seconds).
constexpr int defaultTimeout{20};
18 | |
19 | c10::intrusive_ptr<c10d::TCPStore> _createServer( |
20 | int numWorkers = 1, |
21 | int timeout = defaultTimeout) { |
22 | return c10::make_intrusive<c10d::TCPStore>( |
23 | "127.0.0.1" , |
24 | c10d::TCPStoreOptions{ |
25 | /* port */ 0, |
26 | /* isServer */ true, |
27 | numWorkers, |
28 | /* waitWorkers */ false, |
29 | /* timeout */ std::chrono::seconds(timeout)}); |
30 | } |
31 | |
32 | // Different ports for different tests. |
33 | void testHelper(const std::string& prefix = "" ) { |
34 | constexpr auto numThreads = 16; |
35 | constexpr auto numWorkers = numThreads + 1; |
36 | |
37 | auto serverTCPStore = _createServer(numWorkers); |
38 | |
39 | auto serverStore = |
40 | c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore); |
41 | // server store |
42 | auto serverThread = std::thread([&serverStore, &serverTCPStore] { |
43 | // Wait for all workers to join. |
44 | serverTCPStore->waitForWorkers(); |
45 | |
46 | // Basic set/get on the server store |
47 | c10d::test::set(*serverStore, "key0" , "value0" ); |
48 | c10d::test::set(*serverStore, "key1" , "value1" ); |
49 | c10d::test::set(*serverStore, "key2" , "value2" ); |
50 | c10d::test::check(*serverStore, "key0" , "value0" ); |
51 | c10d::test::check(*serverStore, "key1" , "value1" ); |
52 | c10d::test::check(*serverStore, "key2" , "value2" ); |
53 | serverStore->add("counter" , 1); |
54 | auto numKeys = serverStore->getNumKeys(); |
55 | // We expect 5 keys since 3 are added above, 'counter' is added by the |
56 | // helper thread, and the init key to coordinate workers. |
57 | EXPECT_EQ(numKeys, 5); |
58 | |
59 | // Check compareSet, does not check return value |
60 | c10d::test::compareSet( |
61 | *serverStore, "key0" , "wrongExpectedValue" , "newValue" ); |
62 | c10d::test::check(*serverStore, "key0" , "value0" ); |
63 | c10d::test::compareSet(*serverStore, "key0" , "value0" , "newValue" ); |
64 | c10d::test::check(*serverStore, "key0" , "newValue" ); |
65 | |
66 | auto delSuccess = serverStore->deleteKey("key0" ); |
67 | // Ensure that the key was successfully deleted |
68 | EXPECT_TRUE(delSuccess); |
69 | auto delFailure = serverStore->deleteKey("badKeyName" ); |
70 | // The key was not in the store so the delete operation should have failed |
71 | // and returned false. |
72 | EXPECT_FALSE(delFailure); |
73 | numKeys = serverStore->getNumKeys(); |
74 | EXPECT_EQ(numKeys, 4); |
75 | auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis); |
76 | serverStore->setTimeout(timeout); |
77 | EXPECT_THROW(serverStore->get("key0" ), c10::Error); |
78 | }); |
79 | |
80 | // Hammer on TCPStore |
81 | std::vector<std::thread> threads; |
82 | constexpr auto numIterations = 1000; |
83 | c10d::test::Semaphore sem1, sem2; |
84 | |
85 | c10d::TCPStoreOptions opts{}; |
86 | opts.port = serverTCPStore->getPort(); |
87 | opts.numWorkers = numWorkers; |
88 | |
89 | // Each thread will have a client store to send/recv data |
90 | std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores; |
91 | std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores; |
92 | for (const auto i : c10::irange(numThreads)) { |
93 | clientTCPStores.push_back( |
94 | c10::make_intrusive<c10d::TCPStore>("127.0.0.1" , opts)); |
95 | clientStores.push_back( |
96 | c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i])); |
97 | } |
98 | |
99 | std::string expectedCounterRes = |
100 | std::to_string(numThreads * numIterations + 1); |
101 | |
102 | for (const auto i : c10::irange(numThreads)) { |
103 | threads.emplace_back( |
104 | std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] { |
105 | for (C10_UNUSED const auto j : c10::irange(numIterations)) { |
106 | clientStores[i]->add("counter" , 1); |
107 | } |
108 | // Let each thread set and get key on its client store |
109 | std::string key = "thread_" + std::to_string(i); |
110 | for (const auto j : c10::irange(numIterations)) { |
111 | std::string val = "thread_val_" + std::to_string(j); |
112 | c10d::test::set(*clientStores[i], key, val); |
113 | c10d::test::check(*clientStores[i], key, val); |
114 | } |
115 | |
116 | sem1.post(); |
117 | sem2.wait(); |
118 | // Check the counter results |
119 | c10d::test::check(*clientStores[i], "counter" , expectedCounterRes); |
120 | // Now check other threads' written data |
121 | for (const auto j : c10::irange(numThreads)) { |
122 | if (j == i) { |
123 | continue; |
124 | } |
125 | std::string key = "thread_" + std::to_string(i); |
126 | std::string val = "thread_val_" + std::to_string(numIterations - 1); |
127 | c10d::test::check(*clientStores[i], key, val); |
128 | } |
129 | })); |
130 | } |
131 | |
132 | sem1.wait(numThreads); |
133 | sem2.post(numThreads); |
134 | |
135 | for (auto& thread : threads) { |
136 | thread.join(); |
137 | } |
138 | |
139 | serverThread.join(); |
140 | |
141 | // Clear the store to test that client disconnect won't shutdown the store |
142 | clientStores.clear(); |
143 | clientTCPStores.clear(); |
144 | |
145 | // Check that the counter has the expected value |
146 | c10d::test::check(*serverStore, "counter" , expectedCounterRes); |
147 | |
148 | // Check that each threads' written data from the main thread |
149 | for (const auto i : c10::irange(numThreads)) { |
150 | std::string key = "thread_" + std::to_string(i); |
151 | std::string val = "thread_val_" + std::to_string(numIterations - 1); |
152 | c10d::test::check(*serverStore, key, val); |
153 | } |
154 | } |
155 | |
156 | void testWatchKeyCallback(const std::string& prefix = "" ) { |
157 | // Callback function increments counter of the total number of callbacks that |
158 | // were run |
159 | std::promise<int> numCallbacksExecutedPromise; |
160 | std::atomic<int> numCallbacksExecuted{0}; |
161 | constexpr int numThreads = 16; |
162 | constexpr int keyChangeOperation = 3; |
163 | c10d::WatchKeyCallback callback = |
164 | [=, &numCallbacksExecuted, &numCallbacksExecutedPromise]( |
165 | c10::optional<std::string> /* unused */, |
166 | c10::optional<std::string> /* unused */) { |
167 | numCallbacksExecuted++; |
168 | if (numCallbacksExecuted == numThreads * keyChangeOperation * 2) { |
169 | numCallbacksExecutedPromise.set_value(numCallbacksExecuted); |
170 | } |
171 | }; |
172 | |
173 | const int numWorkers = numThreads + 1; |
174 | auto serverTCPStore = _createServer(numWorkers); |
175 | auto serverStore = |
176 | c10::make_intrusive<c10d::PrefixStore>(prefix, serverTCPStore); |
177 | |
178 | c10d::TCPStoreOptions opts{}; |
179 | opts.port = serverTCPStore->getPort(); |
180 | opts.numWorkers = numWorkers; |
181 | |
182 | // Each thread will have a client store to send/recv data |
183 | std::vector<c10::intrusive_ptr<c10d::TCPStore>> clientTCPStores; |
184 | std::vector<c10::intrusive_ptr<c10d::PrefixStore>> clientStores; |
185 | for (const auto i : c10::irange(numThreads)) { |
186 | clientTCPStores.push_back( |
187 | c10::make_intrusive<c10d::TCPStore>("127.0.0.1" , opts)); |
188 | clientStores.push_back( |
189 | c10::make_intrusive<c10d::PrefixStore>(prefix, clientTCPStores[i])); |
190 | } |
191 | |
192 | // Start watching key on server and client stores |
193 | std::string internalKey = "internalKey" ; |
194 | std::string internalKeyCount = "internalKeyCount" ; |
195 | for (const auto i : c10::irange(numThreads)) { |
196 | serverStore->watchKey(internalKey + std::to_string(i), callback); |
197 | serverStore->watchKey(internalKeyCount + std::to_string(i), callback); |
198 | clientStores[i]->watchKey(internalKey + std::to_string(i), callback); |
199 | clientStores[i]->watchKey(internalKeyCount + std::to_string(i), callback); |
200 | } |
201 | |
202 | std::vector<std::thread> threads; |
203 | std::atomic<int> keyChangeOperationCount{0}; |
204 | for (const auto i : c10::irange(numThreads)) { |
205 | threads.emplace_back(std::thread([=, |
206 | &clientStores, |
207 | &internalKey, |
208 | &internalKeyCount, |
209 | &keyChangeOperationCount] { |
210 | // Let each thread set and get key on its client store |
211 | std::string key = internalKey + std::to_string(i); |
212 | std::string keyCounter = internalKeyCount + std::to_string(i); |
213 | std::string val = "thread_val_" + std::to_string(i); |
214 | // The set, compareSet, add methods count as key change operations |
215 | c10d::test::set(*clientStores[i], key, val); |
216 | c10d::test::compareSet(*clientStores[i], key, val, "newValue" ); |
217 | clientStores[i]->add(keyCounter, i); |
218 | keyChangeOperationCount += keyChangeOperation * 2; |
219 | c10d::test::check(*clientStores[i], key, "newValue" ); |
220 | c10d::test::check(*clientStores[i], keyCounter, std::to_string(i)); |
221 | })); |
222 | } |
223 | |
224 | // Ensures that internal_key has been "set" and "get" |
225 | for (auto& thread : threads) { |
226 | thread.join(); |
227 | } |
228 | |
229 | std::future<int> numCallbacksExecutedFuture = |
230 | numCallbacksExecutedPromise.get_future(); |
231 | std::chrono::milliseconds span(kStoreCallbackTimeoutMillis); |
232 | if (numCallbacksExecutedFuture.wait_for(span) == std::future_status::timeout) |
233 | TORCH_CHECK(false, "Callback execution timed out." ); |
234 | |
235 | // Check number of callbacks executed equal to number of key change operations |
236 | // Wait for all callbacks to be triggered |
237 | EXPECT_EQ(keyChangeOperationCount, numCallbacksExecutedFuture.get()); |
238 | } |
239 | |
240 | TEST(TCPStoreTest, testHelper) { |
241 | testHelper(); |
242 | } |
243 | |
244 | TEST(TCPStoreTest, testHelperPrefix) { |
245 | testHelper("testPrefix" ); |
246 | } |
247 | |
248 | TEST(TCPStoreTest, testWatchKeyCallback) { |
249 | testWatchKeyCallback(); |
250 | } |
251 | |
252 | TEST(TCPStoreTest, testWatchKeyCallbackWithPrefix) { |
253 | testWatchKeyCallback("testPrefix" ); |
254 | } |
255 | |
256 | // Helper function to create a key on the store, watch it, and run the callback |
257 | void testKeyChangeHelper( |
258 | c10d::Store& store, |
259 | std::string key, |
260 | const c10::optional<std::string>& expectedOldValue, |
261 | const c10::optional<std::string>& expectedNewValue) { |
262 | std::exception_ptr eptr = nullptr; |
263 | std::promise<bool> callbackPromise; |
264 | |
265 | // Test the correctness of new_value and old_value |
266 | c10d::WatchKeyCallback callback = [expectedOldValue, |
267 | expectedNewValue, |
268 | &callbackPromise, |
269 | &eptr]( |
270 | c10::optional<std::string> oldValue, |
271 | c10::optional<std::string> newValue) { |
272 | try { |
273 | EXPECT_EQ(expectedOldValue.value_or("NONE" ), oldValue.value_or("NONE" )); |
274 | EXPECT_EQ(expectedNewValue.value_or("NONE" ), newValue.value_or("NONE" )); |
275 | } catch (...) { |
276 | eptr = std::current_exception(); |
277 | } |
278 | callbackPromise.set_value(true); |
279 | }; |
280 | store.watchKey(key, callback); |
281 | |
282 | // Perform the specified update according to key |
283 | if (key == "testEmptyKeyValue" || key == "testRegularKeyValue" || |
284 | key == "testWatchKeyCreate" ) { |
285 | c10d::test::set(store, key, expectedNewValue.value()); |
286 | } else if (key == "testWatchKeyAdd" ) { |
287 | store.add(key, std::stoi(expectedNewValue.value())); |
288 | } else if (key == "testWatchKeyDelete" ) { |
289 | store.deleteKey(key); |
290 | } |
291 | |
292 | // Test that the callback is fired and the expected values are correct |
293 | std::future<bool> callbackFuture = callbackPromise.get_future(); |
294 | std::chrono::milliseconds span(kStoreCallbackTimeoutMillis); |
295 | if (callbackFuture.wait_for(span) == std::future_status::timeout) |
296 | TORCH_CHECK(false, "Callback execution timed out." ); |
297 | |
298 | // Any exceptions raised from asserts should be rethrown |
299 | if (eptr) |
300 | std::rethrow_exception(eptr); |
301 | } |
302 | |
303 | TEST(TCPStoreTest, testKeyEmptyUpdate) { |
304 | auto store = _createServer(); |
305 | |
306 | std::string key = "testEmptyKeyValue" ; |
307 | c10d::test::set(*store, key, "" ); |
308 | store->get(key); |
309 | testKeyChangeHelper(*store, key, "" , "2" ); |
310 | } |
311 | |
312 | TEST(TCPStoreTest, testKeyUpdate) { |
313 | auto store = _createServer(); |
314 | |
315 | std::string key = "testRegularKeyValue" ; |
316 | c10d::test::set(*store, key, "1" ); |
317 | store->get(key); |
318 | testKeyChangeHelper(*store, key, "1" , "2" ); |
319 | } |
320 | |
321 | TEST(TCPStoreTest, testKeyCreate) { |
322 | auto store = _createServer(); |
323 | |
324 | std::string key = "testWatchKeyCreate" ; |
325 | testKeyChangeHelper(*store, key, c10::nullopt, "2" ); |
326 | } |
327 | |
328 | TEST(TCPStoreTest, testKeyAdd) { |
329 | auto store = _createServer(); |
330 | |
331 | std::string key = "testWatchKeyAdd" ; |
332 | testKeyChangeHelper(*store, key, c10::nullopt, "2" ); |
333 | } |
334 | |
335 | TEST(TCPStoreTest, testKeyDelete) { |
336 | auto store = _createServer(); |
337 | |
338 | std::string key = "testWatchKeyDelete" ; |
339 | c10d::test::set(*store, key, "1" ); |
340 | store->get(key); |
341 | testKeyChangeHelper(*store, key, "1" , c10::nullopt); |
342 | } |
343 | |
344 | TEST(TCPStoreTest, testCleanShutdown) { |
345 | int numWorkers = 2; |
346 | |
347 | auto serverTCPStore = std::make_unique<c10d::TCPStore>( |
348 | "127.0.0.1" , |
349 | 0, |
350 | numWorkers, |
351 | true, |
352 | std::chrono::seconds(defaultTimeout), |
353 | /* wait */ false); |
354 | c10d::test::set(*serverTCPStore, "key" , "val" ); |
355 | |
356 | auto clientTCPStore = c10::make_intrusive<c10d::TCPStore>( |
357 | "127.0.0.1" , |
358 | c10d::TCPStoreOptions{ |
359 | /* port */ serverTCPStore->getPort(), |
360 | /* isServer */ false, |
361 | numWorkers, |
362 | /* waitWorkers */ false, |
363 | /* timeout */ std::chrono::seconds(defaultTimeout)}); |
364 | clientTCPStore->get("key" ); |
365 | |
366 | auto clientThread = std::thread([&clientTCPStore] { |
367 | EXPECT_THROW(clientTCPStore->get("invalid_key" ), std::system_error); |
368 | }); |
369 | |
370 | // start server shutdown during a client request |
371 | serverTCPStore = nullptr; |
372 | |
373 | clientThread.join(); |
374 | } |
375 | |
376 | TEST(TCPStoreTest, testMultiTenantStores) { |
377 | c10d::TCPStoreOptions opts{}; |
378 | opts.isServer = true; |
379 | opts.multiTenant = true; |
380 | |
381 | // Construct two server stores on the same port. |
382 | auto store1 = c10::make_intrusive<c10d::TCPStore>("localhost" , opts); |
383 | auto store2 = c10::make_intrusive<c10d::TCPStore>("localhost" , opts); |
384 | |
385 | // Assert that the two stores share the same server. |
386 | c10d::test::set(*store1, "key0" , "value0" ); |
387 | c10d::test::check(*store2, "key0" , "value0" ); |
388 | |
389 | // Dispose the second instance and assert that the server is still alive. |
390 | store2.reset(); |
391 | |
392 | c10d::test::set(*store1, "key0" , "value0" ); |
393 | c10d::test::check(*store1, "key0" , "value0" ); |
394 | } |
395 | |