1 | #include <c10/util/irange.h> |
2 | #include "StoreTestCommon.hpp" |
3 | |
4 | #include <unistd.h> |
5 | |
6 | #include <iostream> |
7 | #include <thread> |
8 | |
9 | #include <torch/csrc/distributed/c10d/HashStore.hpp> |
10 | #include <torch/csrc/distributed/c10d/PrefixStore.hpp> |
11 | |
// Timeout, in milliseconds, installed with setTimeout() before probing get()
// on a key that is absent, so the expected failure surfaces quickly.
constexpr int64_t kShortStoreTimeoutMillis{100};
13 | |
14 | void testGetSet(std::string prefix = "" ) { |
15 | // Basic set/get |
16 | { |
17 | auto hashStore = c10::make_intrusive<c10d::HashStore>(); |
18 | c10d::PrefixStore store(prefix, hashStore); |
19 | c10d::test::set(store, "key0" , "value0" ); |
20 | c10d::test::set(store, "key1" , "value1" ); |
21 | c10d::test::set(store, "key2" , "value2" ); |
22 | c10d::test::check(store, "key0" , "value0" ); |
23 | c10d::test::check(store, "key1" , "value1" ); |
24 | c10d::test::check(store, "key2" , "value2" ); |
25 | |
26 | // Check compareSet, does not check return value |
27 | c10d::test::compareSet(store, "key0" , "wrongExpectedValue" , "newValue" ); |
28 | c10d::test::check(store, "key0" , "value0" ); |
29 | c10d::test::compareSet(store, "key0" , "value0" , "newValue" ); |
30 | c10d::test::check(store, "key0" , "newValue" ); |
31 | |
32 | auto numKeys = store.getNumKeys(); |
33 | EXPECT_EQ(numKeys, 3); |
34 | auto delSuccess = store.deleteKey("key0" ); |
35 | EXPECT_TRUE(delSuccess); |
36 | numKeys = store.getNumKeys(); |
37 | EXPECT_EQ(numKeys, 2); |
38 | auto delFailure = store.deleteKey("badKeyName" ); |
39 | EXPECT_FALSE(delFailure); |
40 | auto timeout = std::chrono::milliseconds(kShortStoreTimeoutMillis); |
41 | store.setTimeout(timeout); |
42 | EXPECT_THROW(store.get("key0" ), std::runtime_error); |
43 | } |
44 | |
45 | // get() waits up to timeout_. |
46 | { |
47 | auto hashStore = c10::make_intrusive<c10d::HashStore>(); |
48 | c10d::PrefixStore store(prefix, hashStore); |
49 | std::thread th([&]() { c10d::test::set(store, "key0" , "value0" ); }); |
50 | c10d::test::check(store, "key0" , "value0" ); |
51 | th.join(); |
52 | } |
53 | } |
54 | |
55 | void stressTestStore(std::string prefix = "" ) { |
56 | // Hammer on HashStore::add |
57 | const auto numThreads = 4; |
58 | const auto numIterations = 100; |
59 | |
60 | std::vector<std::thread> threads; |
61 | c10d::test::Semaphore sem1, sem2; |
62 | auto hashStore = c10::make_intrusive<c10d::HashStore>(); |
63 | c10d::PrefixStore store(prefix, hashStore); |
64 | |
65 | for (C10_UNUSED const auto i : c10::irange(numThreads)) { |
66 | threads.emplace_back(std::thread([&] { |
67 | sem1.post(); |
68 | sem2.wait(); |
69 | for (C10_UNUSED const auto j : c10::irange(numIterations)) { |
70 | store.add("counter" , 1); |
71 | } |
72 | })); |
73 | } |
74 | |
75 | sem1.wait(numThreads); |
76 | sem2.post(numThreads); |
77 | |
78 | for (auto& thread : threads) { |
79 | thread.join(); |
80 | } |
81 | std::string expected = std::to_string(numThreads * numIterations); |
82 | c10d::test::check(store, "counter" , expected); |
83 | } |
84 | |
85 | TEST(HashStoreTest, testGetAndSet) { |
86 | testGetSet(); |
87 | } |
88 | |
89 | TEST(HashStoreTest, testGetAndSetWithPrefix) { |
90 | testGetSet("testPrefix" ); |
91 | } |
92 | |
93 | TEST(HashStoreTest, testStressStore) { |
94 | stressTestStore(); |
95 | } |
96 | |
97 | TEST(HashStoreTest, testStressStoreWithPrefix) { |
98 | stressTestStore("testPrefix" ); |
99 | } |
100 | |